Auto-update: Mon Jul 22 21:02:57 PDT 2024
This commit is contained in:
parent
866e6e31e7
commit
734ef67cc2
2 changed files with 160 additions and 62 deletions
|
@ -19,9 +19,18 @@ from datetime import datetime, timedelta, timezone
|
||||||
from timezonefinder import TimezoneFinder
|
from timezonefinder import TimezoneFinder
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
from srtm import get_data
|
from srtm import get_data
|
||||||
|
from .logs import Logger
|
||||||
|
|
||||||
|
L = Logger("classes", "classes")
|
||||||
|
logger = L.get_module_logger("classes")
|
||||||
|
|
||||||
|
def debug(text: str): logger.debug(text)
|
||||||
|
def info(text: str): logger.info(text)
|
||||||
|
def warn(text: str): logger.warning(text)
|
||||||
|
def err(text: str): logger.error(text)
|
||||||
|
def crit(text: str): logger.critical(text)
|
||||||
|
|
||||||
T = TypeVar('T', bound='Configuration')
|
T = TypeVar('T', bound='Configuration')
|
||||||
|
|
||||||
class Configuration(BaseModel):
|
class Configuration(BaseModel):
|
||||||
HOME: Path = Path.home()
|
HOME: Path = Path.home()
|
||||||
_dir_config: Optional['Configuration'] = None
|
_dir_config: Optional['Configuration'] = None
|
||||||
|
@ -158,11 +167,16 @@ class Configuration(BaseModel):
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
from pydantic import BaseModel, create_model
|
|
||||||
from typing import Any, Dict, List, Union
|
class PoolConfig(BaseModel):
|
||||||
from pathlib import Path
|
ts_ip: str
|
||||||
import yaml
|
ts_id: str
|
||||||
import re
|
wan_ip: str
|
||||||
|
app_port: int
|
||||||
|
db_port: int
|
||||||
|
db_name: str
|
||||||
|
db_user: str
|
||||||
|
db_pass: str
|
||||||
|
|
||||||
class APIConfig(BaseModel):
|
class APIConfig(BaseModel):
|
||||||
HOST: str
|
HOST: str
|
||||||
|
@ -172,6 +186,7 @@ class APIConfig(BaseModel):
|
||||||
PUBLIC: List[str]
|
PUBLIC: List[str]
|
||||||
TRUSTED_SUBNETS: List[str]
|
TRUSTED_SUBNETS: List[str]
|
||||||
MODULES: Any # This will be replaced with a dynamic model
|
MODULES: Any # This will be replaced with a dynamic model
|
||||||
|
POOL: List[Dict[str, Any]] # This replaces the separate PoolConfig
|
||||||
EXTENSIONS: Any # This will be replaced with a dynamic model
|
EXTENSIONS: Any # This will be replaced with a dynamic model
|
||||||
TZ: str
|
TZ: str
|
||||||
KEYS: List[str]
|
KEYS: List[str]
|
||||||
|
@ -284,6 +299,10 @@ class APIConfig(BaseModel):
|
||||||
if name in ['MODULES', 'EXTENSIONS']:
|
if name in ['MODULES', 'EXTENSIONS']:
|
||||||
return self.__dict__[name]
|
return self.__dict__[name]
|
||||||
return super().__getattr__(name)
|
return super().__getattr__(name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def local_db(self):
|
||||||
|
return self.POOL[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active_modules(self) -> List[str]:
|
def active_modules(self) -> List[str]:
|
||||||
|
@ -292,6 +311,90 @@ class APIConfig(BaseModel):
|
||||||
@property
|
@property
|
||||||
def active_extensions(self) -> List[str]:
|
def active_extensions(self) -> List[str]:
|
||||||
return [extension for extension, is_active in self.EXTENSIONS.__dict__.items() if is_active]
|
return [extension for extension, is_active in self.EXTENSIONS.__dict__.items() if is_active]
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_connection(self, pool_entry: Dict[str, Any] = None):
|
||||||
|
if pool_entry is None:
|
||||||
|
pool_entry = self.local_db
|
||||||
|
|
||||||
|
conn = await asyncpg.connect(
|
||||||
|
host=pool_entry['ts_ip'],
|
||||||
|
port=pool_entry['db_port'],
|
||||||
|
user=pool_entry['db_user'],
|
||||||
|
password=pool_entry['db_pass'],
|
||||||
|
database=pool_entry['db_name']
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
async def push_changes(self, query: str, *args):
|
||||||
|
connections = []
|
||||||
|
try:
|
||||||
|
for pool_entry in self.POOL[1:]: # Skip the first (local) database
|
||||||
|
conn = await self.get_connection(pool_entry).__aenter__()
|
||||||
|
connections.append(conn)
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[conn.execute(query, *args) for conn in connections],
|
||||||
|
return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for pool_entry, result in zip(self.POOL[1:], results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
|
||||||
|
else:
|
||||||
|
err(f"Successfully pushed to {pool_entry['ts_ip']}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for conn in connections:
|
||||||
|
await conn.__aexit__(None, None, None)
|
||||||
|
|
||||||
|
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
|
||||||
|
if source_pool_entry is None:
|
||||||
|
source_pool_entry = self.POOL[1] # Default to the second database in the pool
|
||||||
|
|
||||||
|
logger = Logger("DatabaseReplication")
|
||||||
|
async with self.get_connection(source_pool_entry) as source_conn:
|
||||||
|
async with self.get_connection() as dest_conn:
|
||||||
|
# This is a simplistic approach. You might need a more sophisticated
|
||||||
|
# method to determine what data needs to be synced.
|
||||||
|
tables = await source_conn.fetch(
|
||||||
|
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
|
||||||
|
)
|
||||||
|
for table in tables:
|
||||||
|
table_name = table['tablename']
|
||||||
|
await dest_conn.execute(f"TRUNCATE TABLE {table_name}")
|
||||||
|
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
|
||||||
|
if rows:
|
||||||
|
columns = rows[0].keys()
|
||||||
|
await dest_conn.copy_records_to_table(
|
||||||
|
table_name, records=rows, columns=columns
|
||||||
|
)
|
||||||
|
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
|
||||||
|
|
||||||
|
async def sync_schema(self):
|
||||||
|
source_entry = self.POOL[0] # Use the local database as the source
|
||||||
|
schema = await self.get_schema(source_entry)
|
||||||
|
for pool_entry in self.POOL[1:]:
|
||||||
|
await self.apply_schema(pool_entry, schema)
|
||||||
|
info(f"Synced schema to {pool_entry['ts_ip']}")
|
||||||
|
|
||||||
|
async def get_schema(self, pool_entry: Dict[str, Any]):
|
||||||
|
async with self.get_connection(pool_entry) as conn:
|
||||||
|
return await conn.fetch("SELECT * FROM information_schema.columns")
|
||||||
|
|
||||||
|
async def apply_schema(self, pool_entry: Dict[str, Any], schema):
|
||||||
|
async with self.get_connection(pool_entry) as conn:
|
||||||
|
# This is a simplified version. You'd need to handle creating/altering tables,
|
||||||
|
# adding/removing columns, changing data types, etc.
|
||||||
|
for table in schema:
|
||||||
|
await conn.execute(f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {table['table_name']} (
|
||||||
|
{table['column_name']} {table['data_type']}
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
class Location(BaseModel):
|
class Location(BaseModel):
|
||||||
|
|
|
@ -433,30 +433,21 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
|
||||||
async def shortener_form(request: Request):
|
async def shortener_form(request: Request):
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||||
|
|
||||||
|
|
||||||
@serve.post("/s")
|
@serve.post("/s")
|
||||||
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
||||||
async with DB.get_connection() as conn:
|
async with DB.get_connection() as conn:
|
||||||
await conn.execute('''
|
await create_tables(conn)
|
||||||
CREATE TABLE IF NOT EXISTS short_urls (
|
|
||||||
short_code VARCHAR(3) PRIMARY KEY,
|
|
||||||
long_url TEXT NOT NULL,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
if custom_code:
|
if custom_code:
|
||||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
if len(custom_code) != 3 or not custom_code.isalnum():
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
||||||
|
|
||||||
# Check if custom code already exists
|
|
||||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
|
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
|
||||||
if existing:
|
if existing:
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
||||||
|
|
||||||
short_code = custom_code
|
short_code = custom_code
|
||||||
else:
|
else:
|
||||||
# Generate a random 3-character alphanumeric string
|
|
||||||
chars = string.ascii_letters + string.digits
|
chars = string.ascii_letters + string.digits
|
||||||
while True:
|
while True:
|
||||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
short_code = ''.join(random.choice(chars) for _ in range(3))
|
||||||
|
@ -472,42 +463,23 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
|
||||||
short_url = f"https://sij.ai/{short_code}"
|
short_url = f"https://sij.ai/{short_code}"
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
||||||
|
|
||||||
|
async def create_tables(conn):
|
||||||
@serve.get("/s", response_class=HTMLResponse)
|
await conn.execute('''
|
||||||
async def shortener_form(request: Request):
|
CREATE TABLE IF NOT EXISTS short_urls (
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
short_code VARCHAR(3) PRIMARY KEY,
|
||||||
|
long_url TEXT NOT NULL,
|
||||||
@serve.post("/s")
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
|
||||||
async with DB.get_connection() as conn:
|
|
||||||
await create_short_urls_table(conn)
|
|
||||||
|
|
||||||
if custom_code:
|
|
||||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
|
||||||
|
|
||||||
# Check if custom code already exists
|
|
||||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
|
|
||||||
if existing:
|
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
|
||||||
|
|
||||||
short_code = custom_code
|
|
||||||
else:
|
|
||||||
# Generate a random 3-character alphanumeric string
|
|
||||||
chars = string.ascii_letters + string.digits
|
|
||||||
while True:
|
|
||||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
|
||||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code)
|
|
||||||
if not existing:
|
|
||||||
break
|
|
||||||
|
|
||||||
await conn.execute(
|
|
||||||
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
|
|
||||||
short_code, long_url
|
|
||||||
)
|
)
|
||||||
|
''')
|
||||||
short_url = f"https://sij.ai/{short_code}"
|
await conn.execute('''
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
CREATE TABLE IF NOT EXISTS click_logs (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
short_code VARCHAR(3) REFERENCES short_urls(short_code),
|
||||||
|
clicked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
ip_address TEXT,
|
||||||
|
user_agent TEXT
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
|
||||||
@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
|
@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
|
||||||
async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
|
async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
|
||||||
|
@ -520,16 +492,39 @@ async def redirect_short_url(request: Request, short_code: str = PathParam(...,
|
||||||
short_code
|
short_code
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return result['long_url']
|
await conn.execute(
|
||||||
else:
|
'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
|
||||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
short_code, request.client.host, request.headers.get("user-agent")
|
||||||
|
)
|
||||||
|
return result['long_url']
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||||
|
|
||||||
async def create_short_urls_table(conn):
|
@serve.get("/analytics/{short_code}")
|
||||||
await conn.execute('''
|
async def get_analytics(short_code: str):
|
||||||
CREATE TABLE IF NOT EXISTS short_urls (
|
async with DB.get_connection() as conn:
|
||||||
short_code VARCHAR(3) PRIMARY KEY,
|
url_info = await conn.fetchrow(
|
||||||
long_url TEXT NOT NULL,
|
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
short_code
|
||||||
)
|
)
|
||||||
''')
|
if not url_info:
|
||||||
|
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||||
|
|
||||||
|
click_count = await conn.fetchval(
|
||||||
|
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
|
||||||
|
short_code
|
||||||
|
)
|
||||||
|
|
||||||
|
clicks = await conn.fetch(
|
||||||
|
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
|
||||||
|
short_code
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"short_code": short_code,
|
||||||
|
"long_url": url_info['long_url'],
|
||||||
|
"created_at": url_info['created_at'],
|
||||||
|
"total_clicks": click_count,
|
||||||
|
"recent_clicks": [dict(click) for click in clicks]
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue