From 487807bab1a33747be50b73bda801905e4cff875 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Thu, 8 Aug 2024 23:57:50 -0700 Subject: [PATCH] Auto-update: Thu Aug 8 23:57:50 PDT 2024 --- sijapi/__init__.py | 3 +- sijapi/__main__.py | 20 ++++ sijapi/classes.py | 4 + sijapi/config/sys.yaml-example | 1 + sijapi/routers/gis.py | 65 ++++++------ sijapi/routers/serve.py | 178 ++++++++++++++++----------------- sijapi/routers/sys.py | 4 +- 7 files changed, 153 insertions(+), 122 deletions(-) diff --git a/sijapi/__init__.py b/sijapi/__init__.py index 79121bb..ab6db13 100644 --- a/sijapi/__init__.py +++ b/sijapi/__init__.py @@ -26,7 +26,8 @@ Db = Database.load('sys') # HOST = f"{API.BIND}:{API.PORT}" # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] -# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255') + +SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255') MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count()) diff --git a/sijapi/__main__.py b/sijapi/__main__.py index 1eea4ec..b8df014 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -86,6 +86,7 @@ app.add_middleware( allow_headers=['*'], ) + class SimpleAPIKeyMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): client_ip = ipaddress.ip_address(request.client.host) @@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware): if not any(client_ip in subnet for subnet in trusted_subnets): api_key_header = request.headers.get("Authorization") api_key_query = request.query_params.get("api_key") + + # Debug logging for API keys + debug(f"API.KEYS: {API.KEYS}") + if api_key_header: api_key_header = api_key_header.lower().split("bearer ")[-1] + debug(f"API key provided in header: {api_key_header}") + if api_key_query: + debug(f"API key provided in query: {api_key_query}") + if api_key_header not in API.KEYS and api_key_query not in API.KEYS: err(f"Invalid API key provided by a requester.") + if api_key_header: + debug(f"Invalid API key in header: {api_key_header}") + if api_key_query: + debug(f"Invalid API key in query: {api_key_query}") return JSONResponse( status_code=401, content={"detail": "Invalid or missing API key"} ) + else: + if api_key_header in API.KEYS: + debug(f"Valid API key provided in header: {api_key_header}") + if api_key_query in API.KEYS: + debug(f"Valid API key provided in query: {api_key_query}") + response = await call_next(request) return response + # Add the middleware to your FastAPI app app.add_middleware(SimpleAPIKeyMiddleware) diff --git a/sijapi/classes.py b/sijapi/classes.py index cec786b..d0a0d1e 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -259,6 +259,10 @@ class DirConfig: class Database: + @classmethod + def load(cls, config_name: str): + return cls(config_name) + def __init__(self, config_path: str): self.config = self.load_config(config_path) self.pool_connections = {} diff --git a/sijapi/config/sys.yaml-example b/sijapi/config/sys.yaml-example index 0f23a68..f568163 100644 --- a/sijapi/config/sys.yaml-example +++ b/sijapi/config/sys.yaml-example @@ -99,6 +99,7 @@ EXTENSIONS: courtlistener: off macnotify: on shellfish: on + url_shortener: off TZ: 'America/Los_Angeles' diff --git a/sijapi/routers/gis.py b/sijapi/routers/gis.py index 170b85e..363f8c4 100644 --- a/sijapi/routers/gis.py +++ b/sijapi/routers/gis.py @@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M from zoneinfo import ZoneInfo from dateutil.parser import parse as dateutil_parse from typing import Optional, List, Union -from sijapi import L, API, TZ, GEO +from sijapi import L, API, Db, TZ, GEO from sijapi.classes import Location from sijapi.utilities import haversine, assemble_journal_path @@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, ORDER BY datetime DESC ''' - locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations") + locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None)) debug(f"Range locations query returned: {locations}") @@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, ORDER BY datetime DESC LIMIT 1 ''' - location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations") + location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None)) debug(f"Fallback query returned: {location_data}") if location_data: locations = location_data @@ -196,33 +196,40 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, return location_objects if location_objects else [] + async def fetch_last_location_before(datetime: datetime) -> Optional[Location]: - datetime = await dt(datetime) + try: + datetime = await dt(datetime) + + debug(f"Fetching last location before {datetime}") - debug(f"Fetching last location before {datetime}") - - query = ''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, country, - action - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''' + query = ''' + SELECT id, datetime, + ST_X(ST_AsText(location)::geometry) AS longitude, + ST_Y(ST_AsText(location)::geometry) AS latitude, + ST_Z(ST_AsText(location)::geometry) AS elevation, + city, state, zip, street, country, + action + FROM locations + WHERE datetime < $1 + ORDER BY datetime DESC + LIMIT 1 + ''' + + location_data = await Db.execute_read(query, datetime.replace(tzinfo=None)) - location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations") - - if location_data: - debug(f"Last location found: {location_data[0]}") - return Location(**location_data[0]) - else: - debug("No location found before the specified datetime") + if location_data: + debug(f"Last location found: {location_data[0]}") + return Location(**location_data[0]) + else: + debug("No location found before the specified datetime") + return None + except Exception as e: + error(f"Error fetching last location: {str(e)}") return None + + @gis.get("/map", response_class=HTMLResponse) async def generate_map_endpoint( start_date: Optional[str] = Query(None), @@ -244,7 +251,7 @@ async def generate_map_endpoint( async def get_date_range(): query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" - row = await API.execute_read_query(query, table_name="locations") + row = await Db.execute_read(query, table_name="locations") if row and row[0]['min_date'] and row[0]['max_date']: return row[0]['min_date'], row[0]['max_date'] else: @@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) { return m.get_root().render() + async def post_location(location: Location): try: context = location.context or {} @@ -438,14 +446,13 @@ async def post_location(location: Location): $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) ''' - await API.execute_write_query( + await Db.execute_write( query, localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, location.zip, location.street, action, device_type, device_model, device_name, device_os, location.class_, location.type, location.name, location.display_name, location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, - location.suburb, location.county, location.country_code, location.country, - table_name="locations" + location.suburb, location.county, location.country_code, location.country ) info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py index 20998b1..8a73218 100644 --- a/sijapi/routers/serve.py +++ b/sijapi/routers/serve.py @@ -420,97 +420,95 @@ if API.EXTENSIONS.courtlistener: await cl_download_file(download_url, target_path, session) debug(f"Downloaded {file_name} to {target_path}") - -@serve.get("/s", response_class=HTMLResponse) -async def shortener_form(request: Request): - return templates.TemplateResponse("shortener.html", {"request": request}) - - -@serve.post("/s") -async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): - - if custom_code: - if len(custom_code) != 3 or not custom_code.isalnum(): - return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) +if API.EXTENSIONS.url_shortener: + @serve.get("/s", response_class=HTMLResponse) + async def shortener_form(request: Request): + return templates.TemplateResponse("shortener.html", {"request": request}) + + + @serve.post("/s") + async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): + + if custom_code: + if len(custom_code) != 3 or not custom_code.isalnum(): + return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) + + existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") + if existing: + return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) + + short_code = custom_code + else: + chars = string.ascii_letters + string.digits + while True: + debug(f"FOUND THE ISSUE") + short_code = ''.join(random.choice(chars) for _ in range(3)) + existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") + if not existing: + break + + await API.execute_write_query( + 'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)', + short_code, long_url, + table_name="short_urls" + ) + + short_url = f"https://sij.ai/{short_code}" + return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) + + + @serve.get("/{short_code}") + async def redirect_short_url(short_code: str): + results = await API.execute_read_query( + 'SELECT long_url FROM short_urls WHERE short_code = $1', + short_code, + table_name="short_urls" + ) - existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") - if existing: - return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) + if not results: + raise HTTPException(status_code=404, detail="Short URL not found") - short_code = custom_code - else: - chars = string.ascii_letters + string.digits - while True: - debug(f"FOUND THE ISSUE") - short_code = ''.join(random.choice(chars) for _ in range(3)) - existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") - if not existing: - break - - await API.execute_write_query( - 'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)', - short_code, long_url, - table_name="short_urls" - ) - - short_url = f"https://sij.ai/{short_code}" - return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) - - -@serve.get("/{short_code}") -async def redirect_short_url(short_code: str): - results = await API.execute_read_query( - 'SELECT long_url FROM short_urls WHERE short_code = $1', - short_code, - table_name="short_urls" - ) + long_url = results[0].get('long_url') + + if not long_url: + raise HTTPException(status_code=404, detail="Long URL not found") + + # Increment click count (you may want to do this asynchronously) + await API.execute_write_query( + 'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)', + short_code, datetime.now(), + table_name="click_logs" + ) + + return RedirectResponse(url=long_url) - if not results: - raise HTTPException(status_code=404, detail="Short URL not found") - long_url = results[0].get('long_url') - - if not long_url: - raise HTTPException(status_code=404, detail="Long URL not found") - - # Increment click count (you may want to do this asynchronously) - await API.execute_write_query( - 'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)', - short_code, datetime.now(), - table_name="click_logs" - ) - - return RedirectResponse(url=long_url) - - -@serve.get("/analytics/{short_code}") -async def get_analytics(short_code: str): - url_info = await API.execute_read_query( - 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', - short_code, - table_name="short_urls" - ) - if not url_info: - raise HTTPException(status_code=404, detail="Short URL not found") - - click_count = await API.execute_read_query( - 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', - short_code, - table_name="click_logs" - ) - - clicks = await API.execute_read_query( - 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', - short_code, - table_name="click_logs" - ) - - return { - "short_code": short_code, - "long_url": url_info['long_url'], - "created_at": url_info['created_at'], - "total_clicks": click_count, - "recent_clicks": [dict(click) for click in clicks] - } - - + @serve.get("/analytics/{short_code}") + async def get_analytics(short_code: str): + url_info = await API.execute_read_query( + 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', + short_code, + table_name="short_urls" + ) + if not url_info: + raise HTTPException(status_code=404, detail="Short URL not found") + + click_count = await API.execute_read_query( + 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', + short_code, + table_name="click_logs" + ) + + clicks = await API.execute_read_query( + 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', + short_code, + table_name="click_logs" + ) + + return { + "short_code": short_code, + "long_url": url_info['long_url'], + "created_at": url_info['created_at'], + "total_clicks": click_count, + "recent_clicks": [dict(click) for click in clicks] + } diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py index 5a81fb5..d7f8a26 100644 --- a/sijapi/routers/sys.py +++ b/sijapi/routers/sys.py @@ -8,7 +8,7 @@ import httpx import socket from fastapi import APIRouter from tailscale import Tailscale -from sijapi import L, API, TS_ID, SUBNET_BROADCAST +from sijapi import L, API, TS_ID sys = APIRouter(tags=["public", "trusted", "private"]) logger = L.get_module_logger("health") @@ -36,7 +36,7 @@ def get_local_ip(): """Get the server's local IP address.""" s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: - s.connect((f'{SUBNET_BROADCAST}', 1)) + s.connect((f'{API.SUBNET_BROADCAST}', 1)) IP = s.getsockname()[0] except Exception: IP = '127.0.0.1'