Auto-update: Mon Jul 22 21:02:57 PDT 2024

This commit is contained in:
sanj 2024-07-22 21:02:57 -07:00
parent 866e6e31e7
commit 734ef67cc2
2 changed files with 160 additions and 62 deletions

View file

@ -19,9 +19,18 @@ from datetime import datetime, timedelta, timezone
from timezonefinder import TimezoneFinder
from zoneinfo import ZoneInfo
from srtm import get_data
from .logs import Logger
L = Logger("classes", "classes")
logger = L.get_module_logger("classes")
def debug(text: str): logger.debug(text)
def info(text: str): logger.info(text)
def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text)
T = TypeVar('T', bound='Configuration')
class Configuration(BaseModel):
HOME: Path = Path.home()
_dir_config: Optional['Configuration'] = None
@ -158,11 +167,16 @@ class Configuration(BaseModel):
extra = "allow"
arbitrary_types_allowed = True
from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Union
from pathlib import Path
import yaml
import re
class PoolConfig(BaseModel):
ts_ip: str
ts_id: str
wan_ip: str
app_port: int
db_port: int
db_name: str
db_user: str
db_pass: str
class APIConfig(BaseModel):
HOST: str
@ -172,6 +186,7 @@ class APIConfig(BaseModel):
PUBLIC: List[str]
TRUSTED_SUBNETS: List[str]
MODULES: Any # This will be replaced with a dynamic model
POOL: List[Dict[str, Any]] # This replaces the separate PoolConfig
EXTENSIONS: Any # This will be replaced with a dynamic model
TZ: str
KEYS: List[str]
@ -285,6 +300,10 @@ class APIConfig(BaseModel):
return self.__dict__[name]
return super().__getattr__(name)
@property
def local_db(self):
return self.POOL[0]
@property
def active_modules(self) -> List[str]:
return [module for module, is_active in self.MODULES.__dict__.items() if is_active]
@ -293,6 +312,90 @@ class APIConfig(BaseModel):
def active_extensions(self) -> List[str]:
return [extension for extension, is_active in self.EXTENSIONS.__dict__.items() if is_active]
@asynccontextmanager
async def get_connection(self, pool_entry: Dict[str, Any] = None):
if pool_entry is None:
pool_entry = self.local_db
conn = await asyncpg.connect(
host=pool_entry['ts_ip'],
port=pool_entry['db_port'],
user=pool_entry['db_user'],
password=pool_entry['db_pass'],
database=pool_entry['db_name']
)
try:
yield conn
finally:
await conn.close()
async def push_changes(self, query: str, *args):
connections = []
try:
for pool_entry in self.POOL[1:]: # Skip the first (local) database
conn = await self.get_connection(pool_entry).__aenter__()
connections.append(conn)
results = await asyncio.gather(
*[conn.execute(query, *args) for conn in connections],
return_exceptions=True
)
for pool_entry, result in zip(self.POOL[1:], results):
if isinstance(result, Exception):
err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
else:
err(f"Successfully pushed to {pool_entry['ts_ip']}")
finally:
for conn in connections:
await conn.__aexit__(None, None, None)
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
if source_pool_entry is None:
source_pool_entry = self.POOL[1] # Default to the second database in the pool
logger = Logger("DatabaseReplication")
async with self.get_connection(source_pool_entry) as source_conn:
async with self.get_connection() as dest_conn:
# This is a simplistic approach. You might need a more sophisticated
# method to determine what data needs to be synced.
tables = await source_conn.fetch(
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
)
for table in tables:
table_name = table['tablename']
await dest_conn.execute(f"TRUNCATE TABLE {table_name}")
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
if rows:
columns = rows[0].keys()
await dest_conn.copy_records_to_table(
table_name, records=rows, columns=columns
)
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
async def sync_schema(self):
source_entry = self.POOL[0] # Use the local database as the source
schema = await self.get_schema(source_entry)
for pool_entry in self.POOL[1:]:
await self.apply_schema(pool_entry, schema)
info(f"Synced schema to {pool_entry['ts_ip']}")
async def get_schema(self, pool_entry: Dict[str, Any]):
async with self.get_connection(pool_entry) as conn:
return await conn.fetch("SELECT * FROM information_schema.columns")
async def apply_schema(self, pool_entry: Dict[str, Any], schema):
async with self.get_connection(pool_entry) as conn:
# This is a simplified version. You'd need to handle creating/altering tables,
# adding/removing columns, changing data types, etc.
for table in schema:
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {table['table_name']} (
{table['column_name']} {table['data_type']}
)
""")
class Location(BaseModel):
latitude: float

View file

@ -433,30 +433,21 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request})
@serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
async with DB.get_connection() as conn:
await conn.execute('''
CREATE TABLE IF NOT EXISTS short_urls (
short_code VARCHAR(3) PRIMARY KEY,
long_url TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
await create_tables(conn)
if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum():
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
# Check if custom code already exists
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
if existing:
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
short_code = custom_code
else:
# Generate a random 3-character alphanumeric string
chars = string.ascii_letters + string.digits
while True:
short_code = ''.join(random.choice(chars) for _ in range(3))
@ -472,42 +463,23 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
@serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request})
@serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
async with DB.get_connection() as conn:
await create_short_urls_table(conn)
if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum():
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
# Check if custom code already exists
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
if existing:
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
short_code = custom_code
else:
# Generate a random 3-character alphanumeric string
chars = string.ascii_letters + string.digits
while True:
short_code = ''.join(random.choice(chars) for _ in range(3))
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code)
if not existing:
break
await conn.execute(
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
short_code, long_url
async def create_tables(conn):
await conn.execute('''
CREATE TABLE IF NOT EXISTS short_urls (
short_code VARCHAR(3) PRIMARY KEY,
long_url TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
''')
await conn.execute('''
CREATE TABLE IF NOT EXISTS click_logs (
id SERIAL PRIMARY KEY,
short_code VARCHAR(3) REFERENCES short_urls(short_code),
clicked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address TEXT,
user_agent TEXT
)
''')
@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
@ -520,16 +492,39 @@ async def redirect_short_url(request: Request, short_code: str = PathParam(...,
short_code
)
if result:
return result['long_url']
else:
raise HTTPException(status_code=404, detail="Short URL not found")
if result:
await conn.execute(
'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
short_code, request.client.host, request.headers.get("user-agent")
)
return result['long_url']
else:
raise HTTPException(status_code=404, detail="Short URL not found")
async def create_short_urls_table(conn):
await conn.execute('''
CREATE TABLE IF NOT EXISTS short_urls (
short_code VARCHAR(3) PRIMARY KEY,
long_url TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str):
async with DB.get_connection() as conn:
url_info = await conn.fetchrow(
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
short_code
)
''')
if not url_info:
raise HTTPException(status_code=404, detail="Short URL not found")
click_count = await conn.fetchval(
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
short_code
)
clicks = await conn.fetch(
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
short_code
)
return {
"short_code": short_code,
"long_url": url_info['long_url'],
"created_at": url_info['created_at'],
"total_clicks": click_count,
"recent_clicks": [dict(click) for click in clicks]
}