Auto-update: Thu Aug 8 23:57:50 PDT 2024

This commit is contained in:
sanj 2024-08-08 23:57:50 -07:00
parent 46c5db23de
commit 487807bab1
7 changed files with 153 additions and 122 deletions

View file

@ -26,7 +26,8 @@ Db = Database.load('sys')
# HOST = f"{API.BIND}:{API.PORT}" # HOST = f"{API.BIND}:{API.PORT}"
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count()) MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())

View file

@ -86,6 +86,7 @@ app.add_middleware(
allow_headers=['*'], allow_headers=['*'],
) )
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware): class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
client_ip = ipaddress.ip_address(request.client.host) client_ip = ipaddress.ip_address(request.client.host)
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
if not any(client_ip in subnet for subnet in trusted_subnets): if not any(client_ip in subnet for subnet in trusted_subnets):
api_key_header = request.headers.get("Authorization") api_key_header = request.headers.get("Authorization")
api_key_query = request.query_params.get("api_key") api_key_query = request.query_params.get("api_key")
# Debug logging for API keys
debug(f"API.KEYS: {API.KEYS}")
if api_key_header: if api_key_header:
api_key_header = api_key_header.lower().split("bearer ")[-1] api_key_header = api_key_header.lower().split("bearer ")[-1]
debug(f"API key provided in header: {api_key_header}")
if api_key_query:
debug(f"API key provided in query: {api_key_query}")
if api_key_header not in API.KEYS and api_key_query not in API.KEYS: if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
err(f"Invalid API key provided by a requester.") err(f"Invalid API key provided by a requester.")
if api_key_header:
debug(f"Invalid API key in header: {api_key_header}")
if api_key_query:
debug(f"Invalid API key in query: {api_key_query}")
return JSONResponse( return JSONResponse(
status_code=401, status_code=401,
content={"detail": "Invalid or missing API key"} content={"detail": "Invalid or missing API key"}
) )
else:
if api_key_header in API.KEYS:
debug(f"Valid API key provided in header: {api_key_header}")
if api_key_query in API.KEYS:
debug(f"Valid API key provided in query: {api_key_query}")
response = await call_next(request) response = await call_next(request)
return response return response
# Add the middleware to your FastAPI app # Add the middleware to your FastAPI app
app.add_middleware(SimpleAPIKeyMiddleware) app.add_middleware(SimpleAPIKeyMiddleware)

View file

@ -259,6 +259,10 @@ class DirConfig:
class Database: class Database:
@classmethod
def load(cls, config_name: str):
return cls(config_name)
def __init__(self, config_path: str): def __init__(self, config_path: str):
self.config = self.load_config(config_path) self.config = self.load_config(config_path)
self.pool_connections = {} self.pool_connections = {}

View file

@ -99,6 +99,7 @@ EXTENSIONS:
courtlistener: off courtlistener: off
macnotify: on macnotify: on
shellfish: on shellfish: on
url_shortener: off
TZ: 'America/Los_Angeles' TZ: 'America/Los_Angeles'

View file

@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from dateutil.parser import parse as dateutil_parse from dateutil.parser import parse as dateutil_parse
from typing import Optional, List, Union from typing import Optional, List, Union
from sijapi import L, API, TZ, GEO from sijapi import L, API, Db, TZ, GEO
from sijapi.classes import Location from sijapi.classes import Location
from sijapi.utilities import haversine, assemble_journal_path from sijapi.utilities import haversine, assemble_journal_path
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC ORDER BY datetime DESC
''' '''
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations") locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
debug(f"Range locations query returned: {locations}") debug(f"Range locations query returned: {locations}")
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC ORDER BY datetime DESC
LIMIT 1 LIMIT 1
''' '''
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations") location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
debug(f"Fallback query returned: {location_data}") debug(f"Fallback query returned: {location_data}")
if location_data: if location_data:
locations = location_data locations = location_data
@ -196,33 +196,40 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
return location_objects if location_objects else [] return location_objects if location_objects else []
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]: async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
datetime = await dt(datetime) try:
datetime = await dt(datetime)
debug(f"Fetching last location before {datetime}")
debug(f"Fetching last location before {datetime}") query = '''
SELECT id, datetime,
query = ''' ST_X(ST_AsText(location)::geometry) AS longitude,
SELECT id, datetime, ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_X(ST_AsText(location)::geometry) AS longitude, ST_Z(ST_AsText(location)::geometry) AS elevation,
ST_Y(ST_AsText(location)::geometry) AS latitude, city, state, zip, street, country,
ST_Z(ST_AsText(location)::geometry) AS elevation, action
city, state, zip, street, country, FROM locations
action WHERE datetime < $1
FROM locations ORDER BY datetime DESC
WHERE datetime < $1 LIMIT 1
ORDER BY datetime DESC '''
LIMIT 1
''' location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations") if location_data:
debug(f"Last location found: {location_data[0]}")
if location_data: return Location(**location_data[0])
debug(f"Last location found: {location_data[0]}") else:
return Location(**location_data[0]) debug("No location found before the specified datetime")
else: return None
debug("No location found before the specified datetime") except Exception as e:
error(f"Error fetching last location: {str(e)}")
return None return None
@gis.get("/map", response_class=HTMLResponse) @gis.get("/map", response_class=HTMLResponse)
async def generate_map_endpoint( async def generate_map_endpoint(
start_date: Optional[str] = Query(None), start_date: Optional[str] = Query(None),
@ -244,7 +251,7 @@ async def generate_map_endpoint(
async def get_date_range(): async def get_date_range():
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
row = await API.execute_read_query(query, table_name="locations") row = await Db.execute_read(query, table_name="locations")
if row and row[0]['min_date'] and row[0]['max_date']: if row and row[0]['min_date'] and row[0]['max_date']:
return row[0]['min_date'], row[0]['max_date'] return row[0]['min_date'], row[0]['max_date']
else: else:
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
return m.get_root().render() return m.get_root().render()
async def post_location(location: Location): async def post_location(location: Location):
try: try:
context = location.context or {} context = location.context or {}
@ -438,14 +446,13 @@ async def post_location(location: Location):
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
''' '''
await API.execute_write_query( await Db.execute_write(
query, query,
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os, location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name, location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country, location.suburb, location.county, location.country_code, location.country
table_name="locations"
) )
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")

View file

@ -420,97 +420,95 @@ if API.EXTENSIONS.courtlistener:
await cl_download_file(download_url, target_path, session) await cl_download_file(download_url, target_path, session)
debug(f"Downloaded {file_name} to {target_path}") debug(f"Downloaded {file_name} to {target_path}")
if API.EXTENSIONS.url_shortener:
@serve.get("/s", response_class=HTMLResponse) @serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request): async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request}) return templates.TemplateResponse("shortener.html", {"request": request})
@serve.post("/s") @serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
if custom_code: if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum(): if len(custom_code) != 3 or not custom_code.isalnum():
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
if existing:
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
short_code = custom_code
else:
chars = string.ascii_letters + string.digits
while True:
debug(f"FOUND THE ISSUE")
short_code = ''.join(random.choice(chars) for _ in range(3))
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
if not existing:
break
await API.execute_write_query(
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
short_code, long_url,
table_name="short_urls"
)
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
@serve.get("/{short_code}")
async def redirect_short_url(short_code: str):
results = await API.execute_read_query(
'SELECT long_url FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") if not results:
if existing: raise HTTPException(status_code=404, detail="Short URL not found")
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
short_code = custom_code long_url = results[0].get('long_url')
else:
chars = string.ascii_letters + string.digits if not long_url:
while True: raise HTTPException(status_code=404, detail="Long URL not found")
debug(f"FOUND THE ISSUE")
short_code = ''.join(random.choice(chars) for _ in range(3)) # Increment click count (you may want to do this asynchronously)
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") await API.execute_write_query(
if not existing: 'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
break short_code, datetime.now(),
table_name="click_logs"
await API.execute_write_query( )
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
short_code, long_url, return RedirectResponse(url=long_url)
table_name="short_urls"
)
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
@serve.get("/{short_code}")
async def redirect_short_url(short_code: str):
results = await API.execute_read_query(
'SELECT long_url FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
if not results:
raise HTTPException(status_code=404, detail="Short URL not found")
long_url = results[0].get('long_url') @serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str):
if not long_url: url_info = await API.execute_read_query(
raise HTTPException(status_code=404, detail="Long URL not found") 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
short_code,
# Increment click count (you may want to do this asynchronously) table_name="short_urls"
await API.execute_write_query( )
'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)', if not url_info:
short_code, datetime.now(), raise HTTPException(status_code=404, detail="Short URL not found")
table_name="click_logs"
) click_count = await API.execute_read_query(
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
return RedirectResponse(url=long_url) short_code,
table_name="click_logs"
)
@serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str): clicks = await API.execute_read_query(
url_info = await API.execute_read_query( 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', short_code,
short_code, table_name="click_logs"
table_name="short_urls" )
)
if not url_info: return {
raise HTTPException(status_code=404, detail="Short URL not found") "short_code": short_code,
"long_url": url_info['long_url'],
click_count = await API.execute_read_query( "created_at": url_info['created_at'],
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', "total_clicks": click_count,
short_code, "recent_clicks": [dict(click) for click in clicks]
table_name="click_logs" }
)
clicks = await API.execute_read_query(
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
short_code,
table_name="click_logs"
)
return {
"short_code": short_code,
"long_url": url_info['long_url'],
"created_at": url_info['created_at'],
"total_clicks": click_count,
"recent_clicks": [dict(click) for click in clicks]
}

View file

@ -8,7 +8,7 @@ import httpx
import socket import socket
from fastapi import APIRouter from fastapi import APIRouter
from tailscale import Tailscale from tailscale import Tailscale
from sijapi import L, API, TS_ID, SUBNET_BROADCAST from sijapi import L, API, TS_ID
sys = APIRouter(tags=["public", "trusted", "private"]) sys = APIRouter(tags=["public", "trusted", "private"])
logger = L.get_module_logger("health") logger = L.get_module_logger("health")
@ -36,7 +36,7 @@ def get_local_ip():
"""Get the server's local IP address.""" """Get the server's local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
s.connect((f'{SUBNET_BROADCAST}', 1)) s.connect((f'{API.SUBNET_BROADCAST}', 1))
IP = s.getsockname()[0] IP = s.getsockname()[0]
except Exception: except Exception:
IP = '127.0.0.1' IP = '127.0.0.1'