Auto-update: Thu Aug 8 23:57:50 PDT 2024
This commit is contained in:
parent
46c5db23de
commit
487807bab1
7 changed files with 153 additions and 122 deletions
|
@ -26,7 +26,8 @@ Db = Database.load('sys')
|
|||
|
||||
# HOST = f"{API.BIND}:{API.PORT}"
|
||||
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
|
||||
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
|
||||
|
||||
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
|
||||
|
||||
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
|
||||
|
||||
|
|
|
@ -86,6 +86,7 @@ app.add_middleware(
|
|||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
|
||||
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
client_ip = ipaddress.ip_address(request.client.host)
|
||||
|
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
|||
if not any(client_ip in subnet for subnet in trusted_subnets):
|
||||
api_key_header = request.headers.get("Authorization")
|
||||
api_key_query = request.query_params.get("api_key")
|
||||
|
||||
# Debug logging for API keys
|
||||
debug(f"API.KEYS: {API.KEYS}")
|
||||
|
||||
if api_key_header:
|
||||
api_key_header = api_key_header.lower().split("bearer ")[-1]
|
||||
debug(f"API key provided in header: {api_key_header}")
|
||||
if api_key_query:
|
||||
debug(f"API key provided in query: {api_key_query}")
|
||||
|
||||
if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
|
||||
err(f"Invalid API key provided by a requester.")
|
||||
if api_key_header:
|
||||
debug(f"Invalid API key in header: {api_key_header}")
|
||||
if api_key_query:
|
||||
debug(f"Invalid API key in query: {api_key_query}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or missing API key"}
|
||||
)
|
||||
else:
|
||||
if api_key_header in API.KEYS:
|
||||
debug(f"Valid API key provided in header: {api_key_header}")
|
||||
if api_key_query in API.KEYS:
|
||||
debug(f"Valid API key provided in query: {api_key_query}")
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
# Add the middleware to your FastAPI app
|
||||
app.add_middleware(SimpleAPIKeyMiddleware)
|
||||
|
||||
|
|
|
@ -259,6 +259,10 @@ class DirConfig:
|
|||
|
||||
|
||||
class Database:
|
||||
@classmethod
|
||||
def load(cls, config_name: str):
|
||||
return cls(config_name)
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
self.config = self.load_config(config_path)
|
||||
self.pool_connections = {}
|
||||
|
|
|
@ -99,6 +99,7 @@ EXTENSIONS:
|
|||
courtlistener: off
|
||||
macnotify: on
|
||||
shellfish: on
|
||||
url_shortener: off
|
||||
|
||||
TZ: 'America/Los_Angeles'
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
|
|||
from zoneinfo import ZoneInfo
|
||||
from dateutil.parser import parse as dateutil_parse
|
||||
from typing import Optional, List, Union
|
||||
from sijapi import L, API, TZ, GEO
|
||||
from sijapi import L, API, Db, TZ, GEO
|
||||
from sijapi.classes import Location
|
||||
from sijapi.utilities import haversine, assemble_journal_path
|
||||
|
||||
|
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
|||
ORDER BY datetime DESC
|
||||
'''
|
||||
|
||||
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
|
||||
locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
|
||||
|
||||
debug(f"Range locations query returned: {locations}")
|
||||
|
||||
|
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
|||
ORDER BY datetime DESC
|
||||
LIMIT 1
|
||||
'''
|
||||
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
|
||||
location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
|
||||
debug(f"Fallback query returned: {location_data}")
|
||||
if location_data:
|
||||
locations = location_data
|
||||
|
@ -196,33 +196,40 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
|||
|
||||
return location_objects if location_objects else []
|
||||
|
||||
|
||||
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
||||
datetime = await dt(datetime)
|
||||
try:
|
||||
datetime = await dt(datetime)
|
||||
|
||||
debug(f"Fetching last location before {datetime}")
|
||||
|
||||
debug(f"Fetching last location before {datetime}")
|
||||
|
||||
query = '''
|
||||
SELECT id, datetime,
|
||||
ST_X(ST_AsText(location)::geometry) AS longitude,
|
||||
ST_Y(ST_AsText(location)::geometry) AS latitude,
|
||||
ST_Z(ST_AsText(location)::geometry) AS elevation,
|
||||
city, state, zip, street, country,
|
||||
action
|
||||
FROM locations
|
||||
WHERE datetime < $1
|
||||
ORDER BY datetime DESC
|
||||
LIMIT 1
|
||||
'''
|
||||
query = '''
|
||||
SELECT id, datetime,
|
||||
ST_X(ST_AsText(location)::geometry) AS longitude,
|
||||
ST_Y(ST_AsText(location)::geometry) AS latitude,
|
||||
ST_Z(ST_AsText(location)::geometry) AS elevation,
|
||||
city, state, zip, street, country,
|
||||
action
|
||||
FROM locations
|
||||
WHERE datetime < $1
|
||||
ORDER BY datetime DESC
|
||||
LIMIT 1
|
||||
'''
|
||||
|
||||
location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
|
||||
|
||||
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
|
||||
|
||||
if location_data:
|
||||
debug(f"Last location found: {location_data[0]}")
|
||||
return Location(**location_data[0])
|
||||
else:
|
||||
debug("No location found before the specified datetime")
|
||||
if location_data:
|
||||
debug(f"Last location found: {location_data[0]}")
|
||||
return Location(**location_data[0])
|
||||
else:
|
||||
debug("No location found before the specified datetime")
|
||||
return None
|
||||
except Exception as e:
|
||||
error(f"Error fetching last location: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@gis.get("/map", response_class=HTMLResponse)
|
||||
async def generate_map_endpoint(
|
||||
start_date: Optional[str] = Query(None),
|
||||
|
@ -244,7 +251,7 @@ async def generate_map_endpoint(
|
|||
|
||||
async def get_date_range():
|
||||
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
|
||||
row = await API.execute_read_query(query, table_name="locations")
|
||||
row = await Db.execute_read(query, table_name="locations")
|
||||
if row and row[0]['min_date'] and row[0]['max_date']:
|
||||
return row[0]['min_date'], row[0]['max_date']
|
||||
else:
|
||||
|
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
|
|||
|
||||
return m.get_root().render()
|
||||
|
||||
|
||||
async def post_location(location: Location):
|
||||
try:
|
||||
context = location.context or {}
|
||||
|
@ -438,14 +446,13 @@ async def post_location(location: Location):
|
|||
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
|
||||
'''
|
||||
|
||||
await API.execute_write_query(
|
||||
await Db.execute_write(
|
||||
query,
|
||||
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
|
||||
location.zip, location.street, action, device_type, device_model, device_name, device_os,
|
||||
location.class_, location.type, location.name, location.display_name,
|
||||
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
|
||||
location.suburb, location.county, location.country_code, location.country,
|
||||
table_name="locations"
|
||||
location.suburb, location.county, location.country_code, location.country
|
||||
)
|
||||
|
||||
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
|
||||
|
|
|
@ -420,97 +420,95 @@ if API.EXTENSIONS.courtlistener:
|
|||
await cl_download_file(download_url, target_path, session)
|
||||
debug(f"Downloaded {file_name} to {target_path}")
|
||||
|
||||
|
||||
@serve.get("/s", response_class=HTMLResponse)
|
||||
async def shortener_form(request: Request):
|
||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||
|
||||
|
||||
@serve.post("/s")
|
||||
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
||||
|
||||
if custom_code:
|
||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
||||
if API.EXTENSIONS.url_shortener:
|
||||
@serve.get("/s", response_class=HTMLResponse)
|
||||
async def shortener_form(request: Request):
|
||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||
|
||||
|
||||
@serve.post("/s")
|
||||
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
||||
|
||||
if custom_code:
|
||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
||||
|
||||
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
|
||||
if existing:
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
||||
|
||||
short_code = custom_code
|
||||
else:
|
||||
chars = string.ascii_letters + string.digits
|
||||
while True:
|
||||
debug(f"FOUND THE ISSUE")
|
||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
||||
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
|
||||
if not existing:
|
||||
break
|
||||
|
||||
await API.execute_write_query(
|
||||
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
|
||||
short_code, long_url,
|
||||
table_name="short_urls"
|
||||
)
|
||||
|
||||
short_url = f"https://sij.ai/{short_code}"
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
||||
|
||||
|
||||
@serve.get("/{short_code}")
|
||||
async def redirect_short_url(short_code: str):
|
||||
results = await API.execute_read_query(
|
||||
'SELECT long_url FROM short_urls WHERE short_code = $1',
|
||||
short_code,
|
||||
table_name="short_urls"
|
||||
)
|
||||
|
||||
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
|
||||
if existing:
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
||||
if not results:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
short_code = custom_code
|
||||
else:
|
||||
chars = string.ascii_letters + string.digits
|
||||
while True:
|
||||
debug(f"FOUND THE ISSUE")
|
||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
||||
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
|
||||
if not existing:
|
||||
break
|
||||
|
||||
await API.execute_write_query(
|
||||
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
|
||||
short_code, long_url,
|
||||
table_name="short_urls"
|
||||
)
|
||||
|
||||
short_url = f"https://sij.ai/{short_code}"
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
||||
|
||||
|
||||
@serve.get("/{short_code}")
|
||||
async def redirect_short_url(short_code: str):
|
||||
results = await API.execute_read_query(
|
||||
'SELECT long_url FROM short_urls WHERE short_code = $1',
|
||||
short_code,
|
||||
table_name="short_urls"
|
||||
)
|
||||
long_url = results[0].get('long_url')
|
||||
|
||||
if not long_url:
|
||||
raise HTTPException(status_code=404, detail="Long URL not found")
|
||||
|
||||
# Increment click count (you may want to do this asynchronously)
|
||||
await API.execute_write_query(
|
||||
'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
|
||||
short_code, datetime.now(),
|
||||
table_name="click_logs"
|
||||
)
|
||||
|
||||
return RedirectResponse(url=long_url)
|
||||
|
||||
if not results:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
long_url = results[0].get('long_url')
|
||||
|
||||
if not long_url:
|
||||
raise HTTPException(status_code=404, detail="Long URL not found")
|
||||
|
||||
# Increment click count (you may want to do this asynchronously)
|
||||
await API.execute_write_query(
|
||||
'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
|
||||
short_code, datetime.now(),
|
||||
table_name="click_logs"
|
||||
)
|
||||
|
||||
return RedirectResponse(url=long_url)
|
||||
|
||||
|
||||
@serve.get("/analytics/{short_code}")
|
||||
async def get_analytics(short_code: str):
|
||||
url_info = await API.execute_read_query(
|
||||
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
|
||||
short_code,
|
||||
table_name="short_urls"
|
||||
)
|
||||
if not url_info:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
click_count = await API.execute_read_query(
|
||||
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
|
||||
short_code,
|
||||
table_name="click_logs"
|
||||
)
|
||||
|
||||
clicks = await API.execute_read_query(
|
||||
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
|
||||
short_code,
|
||||
table_name="click_logs"
|
||||
)
|
||||
|
||||
return {
|
||||
"short_code": short_code,
|
||||
"long_url": url_info['long_url'],
|
||||
"created_at": url_info['created_at'],
|
||||
"total_clicks": click_count,
|
||||
"recent_clicks": [dict(click) for click in clicks]
|
||||
}
|
||||
|
||||
|
||||
@serve.get("/analytics/{short_code}")
|
||||
async def get_analytics(short_code: str):
|
||||
url_info = await API.execute_read_query(
|
||||
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
|
||||
short_code,
|
||||
table_name="short_urls"
|
||||
)
|
||||
if not url_info:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
click_count = await API.execute_read_query(
|
||||
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
|
||||
short_code,
|
||||
table_name="click_logs"
|
||||
)
|
||||
|
||||
clicks = await API.execute_read_query(
|
||||
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
|
||||
short_code,
|
||||
table_name="click_logs"
|
||||
)
|
||||
|
||||
return {
|
||||
"short_code": short_code,
|
||||
"long_url": url_info['long_url'],
|
||||
"created_at": url_info['created_at'],
|
||||
"total_clicks": click_count,
|
||||
"recent_clicks": [dict(click) for click in clicks]
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import httpx
|
|||
import socket
|
||||
from fastapi import APIRouter
|
||||
from tailscale import Tailscale
|
||||
from sijapi import L, API, TS_ID, SUBNET_BROADCAST
|
||||
from sijapi import L, API, TS_ID
|
||||
|
||||
sys = APIRouter(tags=["public", "trusted", "private"])
|
||||
logger = L.get_module_logger("health")
|
||||
|
@ -36,7 +36,7 @@ def get_local_ip():
|
|||
"""Get the server's local IP address."""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
s.connect((f'{SUBNET_BROADCAST}', 1))
|
||||
s.connect((f'{API.SUBNET_BROADCAST}', 1))
|
||||
IP = s.getsockname()[0]
|
||||
except Exception:
|
||||
IP = '127.0.0.1'
|
||||
|
|
Loading…
Reference in a new issue