Auto-update: Thu Aug 8 23:57:50 PDT 2024

This commit is contained in:
sanj 2024-08-08 23:57:50 -07:00
parent 46c5db23de
commit 487807bab1
7 changed files with 153 additions and 122 deletions

View file

@ -26,7 +26,8 @@ Db = Database.load('sys')
# HOST = f"{API.BIND}:{API.PORT}"
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())

View file

@ -86,6 +86,7 @@ app.add_middleware(
allow_headers=['*'],
)
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
client_ip = ipaddress.ip_address(request.client.host)
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
if not any(client_ip in subnet for subnet in trusted_subnets):
api_key_header = request.headers.get("Authorization")
api_key_query = request.query_params.get("api_key")
# Debug logging for API keys
debug(f"API.KEYS: {API.KEYS}")
if api_key_header:
api_key_header = api_key_header.lower().split("bearer ")[-1]
debug(f"API key provided in header: {api_key_header}")
if api_key_query:
debug(f"API key provided in query: {api_key_query}")
if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
err(f"Invalid API key provided by a requester.")
if api_key_header:
debug(f"Invalid API key in header: {api_key_header}")
if api_key_query:
debug(f"Invalid API key in query: {api_key_query}")
return JSONResponse(
status_code=401,
content={"detail": "Invalid or missing API key"}
)
else:
if api_key_header in API.KEYS:
debug(f"Valid API key provided in header: {api_key_header}")
if api_key_query in API.KEYS:
debug(f"Valid API key provided in query: {api_key_query}")
response = await call_next(request)
return response
# Add the middleware to your FastAPI app
app.add_middleware(SimpleAPIKeyMiddleware)

View file

@ -259,6 +259,10 @@ class DirConfig:
class Database:
@classmethod
def load(cls, config_name: str):
return cls(config_name)
def __init__(self, config_path: str):
self.config = self.load_config(config_path)
self.pool_connections = {}

View file

@ -99,6 +99,7 @@ EXTENSIONS:
courtlistener: off
macnotify: on
shellfish: on
url_shortener: off
TZ: 'America/Los_Angeles'

View file

@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
from zoneinfo import ZoneInfo
from dateutil.parser import parse as dateutil_parse
from typing import Optional, List, Union
from sijapi import L, API, TZ, GEO
from sijapi import L, API, Db, TZ, GEO
from sijapi.classes import Location
from sijapi.utilities import haversine, assemble_journal_path
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC
'''
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
debug(f"Range locations query returned: {locations}")
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC
LIMIT 1
'''
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
debug(f"Fallback query returned: {location_data}")
if location_data:
locations = location_data
@ -196,33 +196,40 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
return location_objects if location_objects else []
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
datetime = await dt(datetime)
try:
datetime = await dt(datetime)
debug(f"Fetching last location before {datetime}")
debug(f"Fetching last location before {datetime}")
query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street, country,
action
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
'''
query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street, country,
action
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
'''
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
if location_data:
debug(f"Last location found: {location_data[0]}")
return Location(**location_data[0])
else:
debug("No location found before the specified datetime")
if location_data:
debug(f"Last location found: {location_data[0]}")
return Location(**location_data[0])
else:
debug("No location found before the specified datetime")
return None
except Exception as e:
error(f"Error fetching last location: {str(e)}")
return None
@gis.get("/map", response_class=HTMLResponse)
async def generate_map_endpoint(
start_date: Optional[str] = Query(None),
@ -244,7 +251,7 @@ async def generate_map_endpoint(
async def get_date_range():
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
row = await API.execute_read_query(query, table_name="locations")
row = await Db.execute_read(query, table_name="locations")
if row and row[0]['min_date'] and row[0]['max_date']:
return row[0]['min_date'], row[0]['max_date']
else:
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
return m.get_root().render()
async def post_location(location: Location):
try:
context = location.context or {}
@ -438,14 +446,13 @@ async def post_location(location: Location):
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
'''
await API.execute_write_query(
await Db.execute_write(
query,
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country,
table_name="locations"
location.suburb, location.county, location.country_code, location.country
)
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")

View file

@ -420,97 +420,95 @@ if API.EXTENSIONS.courtlistener:
await cl_download_file(download_url, target_path, session)
debug(f"Downloaded {file_name} to {target_path}")
@serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request})
if API.EXTENSIONS.url_shortener:
@serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request})
@serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
@serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum():
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum():
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
if existing:
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
if existing:
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
short_code = custom_code
else:
chars = string.ascii_letters + string.digits
while True:
debug(f"FOUND THE ISSUE")
short_code = ''.join(random.choice(chars) for _ in range(3))
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
if not existing:
break
short_code = custom_code
else:
chars = string.ascii_letters + string.digits
while True:
debug(f"FOUND THE ISSUE")
short_code = ''.join(random.choice(chars) for _ in range(3))
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
if not existing:
break
await API.execute_write_query(
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
short_code, long_url,
table_name="short_urls"
)
await API.execute_write_query(
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
short_code, long_url,
table_name="short_urls"
)
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
@serve.get("/{short_code}")
async def redirect_short_url(short_code: str):
results = await API.execute_read_query(
'SELECT long_url FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
@serve.get("/{short_code}")
async def redirect_short_url(short_code: str):
results = await API.execute_read_query(
'SELECT long_url FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
if not results:
raise HTTPException(status_code=404, detail="Short URL not found")
if not results:
raise HTTPException(status_code=404, detail="Short URL not found")
long_url = results[0].get('long_url')
long_url = results[0].get('long_url')
if not long_url:
raise HTTPException(status_code=404, detail="Long URL not found")
if not long_url:
raise HTTPException(status_code=404, detail="Long URL not found")
# Increment click count (you may want to do this asynchronously)
await API.execute_write_query(
'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
short_code, datetime.now(),
table_name="click_logs"
)
# Increment click count (you may want to do this asynchronously)
await API.execute_write_query(
'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
short_code, datetime.now(),
table_name="click_logs"
)
return RedirectResponse(url=long_url)
return RedirectResponse(url=long_url)
@serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str):
url_info = await API.execute_read_query(
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
if not url_info:
raise HTTPException(status_code=404, detail="Short URL not found")
@serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str):
url_info = await API.execute_read_query(
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
if not url_info:
raise HTTPException(status_code=404, detail="Short URL not found")
click_count = await API.execute_read_query(
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
short_code,
table_name="click_logs"
)
clicks = await API.execute_read_query(
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
short_code,
table_name="click_logs"
)
return {
"short_code": short_code,
"long_url": url_info['long_url'],
"created_at": url_info['created_at'],
"total_clicks": click_count,
"recent_clicks": [dict(click) for click in clicks]
}
click_count = await API.execute_read_query(
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
short_code,
table_name="click_logs"
)
clicks = await API.execute_read_query(
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
short_code,
table_name="click_logs"
)
return {
"short_code": short_code,
"long_url": url_info['long_url'],
"created_at": url_info['created_at'],
"total_clicks": click_count,
"recent_clicks": [dict(click) for click in clicks]
}

View file

@ -8,7 +8,7 @@ import httpx
import socket
from fastapi import APIRouter
from tailscale import Tailscale
from sijapi import L, API, TS_ID, SUBNET_BROADCAST
from sijapi import L, API, TS_ID
sys = APIRouter(tags=["public", "trusted", "private"])
logger = L.get_module_logger("health")
@ -36,7 +36,7 @@ def get_local_ip():
"""Get the server's local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect((f'{SUBNET_BROADCAST}', 1))
s.connect((f'{API.SUBNET_BROADCAST}', 1))
IP = s.getsockname()[0]
except Exception:
IP = '127.0.0.1'