Auto-update: Mon Jul 22 21:02:57 PDT 2024
This commit is contained in:
parent
866e6e31e7
commit
734ef67cc2
2 changed files with 160 additions and 62 deletions
|
@ -19,9 +19,18 @@ from datetime import datetime, timedelta, timezone
|
|||
from timezonefinder import TimezoneFinder
|
||||
from zoneinfo import ZoneInfo
|
||||
from srtm import get_data
|
||||
from .logs import Logger
|
||||
|
||||
L = Logger("classes", "classes")
|
||||
logger = L.get_module_logger("classes")
|
||||
|
||||
def debug(text: str): logger.debug(text)
|
||||
def info(text: str): logger.info(text)
|
||||
def warn(text: str): logger.warning(text)
|
||||
def err(text: str): logger.error(text)
|
||||
def crit(text: str): logger.critical(text)
|
||||
|
||||
T = TypeVar('T', bound='Configuration')
|
||||
|
||||
class Configuration(BaseModel):
|
||||
HOME: Path = Path.home()
|
||||
_dir_config: Optional['Configuration'] = None
|
||||
|
@ -158,11 +167,16 @@ class Configuration(BaseModel):
|
|||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from typing import Any, Dict, List, Union
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
import re
|
||||
|
||||
class PoolConfig(BaseModel):
|
||||
ts_ip: str
|
||||
ts_id: str
|
||||
wan_ip: str
|
||||
app_port: int
|
||||
db_port: int
|
||||
db_name: str
|
||||
db_user: str
|
||||
db_pass: str
|
||||
|
||||
class APIConfig(BaseModel):
|
||||
HOST: str
|
||||
|
@ -172,6 +186,7 @@ class APIConfig(BaseModel):
|
|||
PUBLIC: List[str]
|
||||
TRUSTED_SUBNETS: List[str]
|
||||
MODULES: Any # This will be replaced with a dynamic model
|
||||
POOL: List[Dict[str, Any]] # This replaces the separate PoolConfig
|
||||
EXTENSIONS: Any # This will be replaced with a dynamic model
|
||||
TZ: str
|
||||
KEYS: List[str]
|
||||
|
@ -285,6 +300,10 @@ class APIConfig(BaseModel):
|
|||
return self.__dict__[name]
|
||||
return super().__getattr__(name)
|
||||
|
||||
@property
|
||||
def local_db(self):
|
||||
return self.POOL[0]
|
||||
|
||||
@property
|
||||
def active_modules(self) -> List[str]:
|
||||
return [module for module, is_active in self.MODULES.__dict__.items() if is_active]
|
||||
|
@ -293,6 +312,90 @@ class APIConfig(BaseModel):
|
|||
def active_extensions(self) -> List[str]:
|
||||
return [extension for extension, is_active in self.EXTENSIONS.__dict__.items() if is_active]
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self, pool_entry: Dict[str, Any] = None):
|
||||
if pool_entry is None:
|
||||
pool_entry = self.local_db
|
||||
|
||||
conn = await asyncpg.connect(
|
||||
host=pool_entry['ts_ip'],
|
||||
port=pool_entry['db_port'],
|
||||
user=pool_entry['db_user'],
|
||||
password=pool_entry['db_pass'],
|
||||
database=pool_entry['db_name']
|
||||
)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
async def push_changes(self, query: str, *args):
|
||||
connections = []
|
||||
try:
|
||||
for pool_entry in self.POOL[1:]: # Skip the first (local) database
|
||||
conn = await self.get_connection(pool_entry).__aenter__()
|
||||
connections.append(conn)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[conn.execute(query, *args) for conn in connections],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
for pool_entry, result in zip(self.POOL[1:], results):
|
||||
if isinstance(result, Exception):
|
||||
err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
|
||||
else:
|
||||
err(f"Successfully pushed to {pool_entry['ts_ip']}")
|
||||
|
||||
finally:
|
||||
for conn in connections:
|
||||
await conn.__aexit__(None, None, None)
|
||||
|
||||
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
|
||||
if source_pool_entry is None:
|
||||
source_pool_entry = self.POOL[1] # Default to the second database in the pool
|
||||
|
||||
logger = Logger("DatabaseReplication")
|
||||
async with self.get_connection(source_pool_entry) as source_conn:
|
||||
async with self.get_connection() as dest_conn:
|
||||
# This is a simplistic approach. You might need a more sophisticated
|
||||
# method to determine what data needs to be synced.
|
||||
tables = await source_conn.fetch(
|
||||
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
|
||||
)
|
||||
for table in tables:
|
||||
table_name = table['tablename']
|
||||
await dest_conn.execute(f"TRUNCATE TABLE {table_name}")
|
||||
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
|
||||
if rows:
|
||||
columns = rows[0].keys()
|
||||
await dest_conn.copy_records_to_table(
|
||||
table_name, records=rows, columns=columns
|
||||
)
|
||||
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
|
||||
|
||||
async def sync_schema(self):
|
||||
source_entry = self.POOL[0] # Use the local database as the source
|
||||
schema = await self.get_schema(source_entry)
|
||||
for pool_entry in self.POOL[1:]:
|
||||
await self.apply_schema(pool_entry, schema)
|
||||
info(f"Synced schema to {pool_entry['ts_ip']}")
|
||||
|
||||
async def get_schema(self, pool_entry: Dict[str, Any]):
|
||||
async with self.get_connection(pool_entry) as conn:
|
||||
return await conn.fetch("SELECT * FROM information_schema.columns")
|
||||
|
||||
async def apply_schema(self, pool_entry: Dict[str, Any], schema):
|
||||
async with self.get_connection(pool_entry) as conn:
|
||||
# This is a simplified version. You'd need to handle creating/altering tables,
|
||||
# adding/removing columns, changing data types, etc.
|
||||
for table in schema:
|
||||
await conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {table['table_name']} (
|
||||
{table['column_name']} {table['data_type']}
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
class Location(BaseModel):
|
||||
latitude: float
|
||||
|
|
|
@ -433,10 +433,37 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
|
|||
async def shortener_form(request: Request):
|
||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||
|
||||
|
||||
@serve.post("/s")
|
||||
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
||||
async with DB.get_connection() as conn:
|
||||
await create_tables(conn)
|
||||
|
||||
if custom_code:
|
||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
||||
|
||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
|
||||
if existing:
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
||||
|
||||
short_code = custom_code
|
||||
else:
|
||||
chars = string.ascii_letters + string.digits
|
||||
while True:
|
||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code)
|
||||
if not existing:
|
||||
break
|
||||
|
||||
await conn.execute(
|
||||
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
|
||||
short_code, long_url
|
||||
)
|
||||
|
||||
short_url = f"https://sij.ai/{short_code}"
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
||||
|
||||
async def create_tables(conn):
|
||||
await conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS short_urls (
|
||||
short_code VARCHAR(3) PRIMARY KEY,
|
||||
|
@ -444,70 +471,15 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
|
|||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
if custom_code:
|
||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
||||
|
||||
# Check if custom code already exists
|
||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
|
||||
if existing:
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
||||
|
||||
short_code = custom_code
|
||||
else:
|
||||
# Generate a random 3-character alphanumeric string
|
||||
chars = string.ascii_letters + string.digits
|
||||
while True:
|
||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code)
|
||||
if not existing:
|
||||
break
|
||||
|
||||
await conn.execute(
|
||||
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
|
||||
short_code, long_url
|
||||
await conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS click_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
short_code VARCHAR(3) REFERENCES short_urls(short_code),
|
||||
clicked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT
|
||||
)
|
||||
|
||||
short_url = f"https://sij.ai/{short_code}"
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
||||
|
||||
|
||||
@serve.get("/s", response_class=HTMLResponse)
|
||||
async def shortener_form(request: Request):
|
||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||
|
||||
@serve.post("/s")
|
||||
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
|
||||
async with DB.get_connection() as conn:
|
||||
await create_short_urls_table(conn)
|
||||
|
||||
if custom_code:
|
||||
if len(custom_code) != 3 or not custom_code.isalnum():
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
|
||||
|
||||
# Check if custom code already exists
|
||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
|
||||
if existing:
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
|
||||
|
||||
short_code = custom_code
|
||||
else:
|
||||
# Generate a random 3-character alphanumeric string
|
||||
chars = string.ascii_letters + string.digits
|
||||
while True:
|
||||
short_code = ''.join(random.choice(chars) for _ in range(3))
|
||||
existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code)
|
||||
if not existing:
|
||||
break
|
||||
|
||||
await conn.execute(
|
||||
'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
|
||||
short_code, long_url
|
||||
)
|
||||
|
||||
short_url = f"https://sij.ai/{short_code}"
|
||||
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
|
||||
''')
|
||||
|
||||
@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
|
||||
async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
|
||||
|
@ -521,15 +493,38 @@ async def redirect_short_url(request: Request, short_code: str = PathParam(...,
|
|||
)
|
||||
|
||||
if result:
|
||||
await conn.execute(
|
||||
'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
|
||||
short_code, request.client.host, request.headers.get("user-agent")
|
||||
)
|
||||
return result['long_url']
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
async def create_short_urls_table(conn):
|
||||
await conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS short_urls (
|
||||
short_code VARCHAR(3) PRIMARY KEY,
|
||||
long_url TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
@serve.get("/analytics/{short_code}")
|
||||
async def get_analytics(short_code: str):
|
||||
async with DB.get_connection() as conn:
|
||||
url_info = await conn.fetchrow(
|
||||
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
|
||||
short_code
|
||||
)
|
||||
''')
|
||||
if not url_info:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
click_count = await conn.fetchval(
|
||||
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
|
||||
short_code
|
||||
)
|
||||
|
||||
clicks = await conn.fetch(
|
||||
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
|
||||
short_code
|
||||
)
|
||||
|
||||
return {
|
||||
"short_code": short_code,
|
||||
"long_url": url_info['long_url'],
|
||||
"created_at": url_info['created_at'],
|
||||
"total_clicks": click_count,
|
||||
"recent_clicks": [dict(click) for click in clicks]
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue