diff --git a/sijapi/classes.py b/sijapi/classes.py index 1958965..1aaad0a 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -19,9 +19,18 @@ from datetime import datetime, timedelta, timezone from timezonefinder import TimezoneFinder from zoneinfo import ZoneInfo from srtm import get_data +from .logs import Logger + +L = Logger("classes", "classes") +logger = L.get_module_logger("classes") + +def debug(text: str): logger.debug(text) +def info(text: str): logger.info(text) +def warn(text: str): logger.warning(text) +def err(text: str): logger.error(text) +def crit(text: str): logger.critical(text) T = TypeVar('T', bound='Configuration') - class Configuration(BaseModel): HOME: Path = Path.home() _dir_config: Optional['Configuration'] = None @@ -158,11 +167,16 @@ class Configuration(BaseModel): extra = "allow" arbitrary_types_allowed = True -from pydantic import BaseModel, create_model -from typing import Any, Dict, List, Union -from pathlib import Path -import yaml -import re + +class PoolConfig(BaseModel): + ts_ip: str + ts_id: str + wan_ip: str + app_port: int + db_port: int + db_name: str + db_user: str + db_pass: str class APIConfig(BaseModel): HOST: str @@ -172,6 +186,7 @@ class APIConfig(BaseModel): PUBLIC: List[str] TRUSTED_SUBNETS: List[str] MODULES: Any # This will be replaced with a dynamic model + POOL: List[Dict[str, Any]] # This replaces the separate PoolConfig EXTENSIONS: Any # This will be replaced with a dynamic model TZ: str KEYS: List[str] @@ -284,6 +299,10 @@ class APIConfig(BaseModel): if name in ['MODULES', 'EXTENSIONS']: return self.__dict__[name] return super().__getattr__(name) + + @property + def local_db(self): + return self.POOL[0] @property def active_modules(self) -> List[str]: @@ -292,6 +311,90 @@ class APIConfig(BaseModel): @property def active_extensions(self) -> List[str]: return [extension for extension, is_active in self.EXTENSIONS.__dict__.items() if is_active] + + @asynccontextmanager + async def get_connection(self, pool_entry: Dict[str, Any] = None): + if pool_entry is None: + pool_entry = self.local_db + + conn = await asyncpg.connect( + host=pool_entry['ts_ip'], + port=pool_entry['db_port'], + user=pool_entry['db_user'], + password=pool_entry['db_pass'], + database=pool_entry['db_name'] + ) + try: + yield conn + finally: + await conn.close() + + async def push_changes(self, query: str, *args): + connections = [] + try: + for pool_entry in self.POOL[1:]: # Skip the first (local) database + conn = await self.get_connection(pool_entry).__aenter__() + connections.append(conn) + + results = await asyncio.gather( + *[conn.execute(query, *args) for conn in connections], + return_exceptions=True + ) + + for pool_entry, result in zip(self.POOL[1:], results): + if isinstance(result, Exception): + err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}") + else: + err(f"Successfully pushed to {pool_entry['ts_ip']}") + + finally: + for conn in connections: + await conn.__aexit__(None, None, None) + + async def pull_changes(self, source_pool_entry: Dict[str, Any] = None): + if source_pool_entry is None: + source_pool_entry = self.POOL[1] # Default to the second database in the pool + + logger = Logger("DatabaseReplication") + async with self.get_connection(source_pool_entry) as source_conn: + async with self.get_connection() as dest_conn: + # This is a simplistic approach. You might need a more sophisticated + # method to determine what data needs to be synced. + tables = await source_conn.fetch( + "SELECT tablename FROM pg_tables WHERE schemaname = 'public'" + ) + for table in tables: + table_name = table['tablename'] + await dest_conn.execute(f"TRUNCATE TABLE {table_name}") + rows = await source_conn.fetch(f"SELECT * FROM {table_name}") + if rows: + columns = rows[0].keys() + await dest_conn.copy_records_to_table( + table_name, records=rows, columns=columns + ) + info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}") + + async def sync_schema(self): + source_entry = self.POOL[0] # Use the local database as the source + schema = await self.get_schema(source_entry) + for pool_entry in self.POOL[1:]: + await self.apply_schema(pool_entry, schema) + info(f"Synced schema to {pool_entry['ts_ip']}") + + async def get_schema(self, pool_entry: Dict[str, Any]): + async with self.get_connection(pool_entry) as conn: + return await conn.fetch("SELECT * FROM information_schema.columns") + + async def apply_schema(self, pool_entry: Dict[str, Any], schema): + async with self.get_connection(pool_entry) as conn: + # This is a simplified version. You'd need to handle creating/altering tables, + # adding/removing columns, changing data types, etc. + for table in schema: + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {table['table_name']} ( + {table['column_name']} {table['data_type']} + ) + """) class Location(BaseModel): diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py index a9dcf02..118bd36 100644 --- a/sijapi/routers/serve.py +++ b/sijapi/routers/serve.py @@ -433,30 +433,21 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: async def shortener_form(request: Request): return templates.TemplateResponse("shortener.html", {"request": request}) - @serve.post("/s") async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): async with DB.get_connection() as conn: - await conn.execute(''' - CREATE TABLE IF NOT EXISTS short_urls ( - short_code VARCHAR(3) PRIMARY KEY, - long_url TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - ''') + await create_tables(conn) if custom_code: if len(custom_code) != 3 or not custom_code.isalnum(): return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) - # Check if custom code already exists existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code) if existing: return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) short_code = custom_code else: - # Generate a random 3-character alphanumeric string chars = string.ascii_letters + string.digits while True: short_code = ''.join(random.choice(chars) for _ in range(3)) @@ -472,42 +463,23 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c short_url = f"https://sij.ai/{short_code}" return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) - -@serve.get("/s", response_class=HTMLResponse) -async def shortener_form(request: Request): - return templates.TemplateResponse("shortener.html", {"request": request}) - -@serve.post("/s") -async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): - async with DB.get_connection() as conn: - await create_short_urls_table(conn) - - if custom_code: - if len(custom_code) != 3 or not custom_code.isalnum(): - return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) - - # Check if custom code already exists - existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code) - if existing: - return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) - - short_code = custom_code - else: - # Generate a random 3-character alphanumeric string - chars = string.ascii_letters + string.digits - while True: - short_code = ''.join(random.choice(chars) for _ in range(3)) - existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code) - if not existing: - break - - await conn.execute( - 'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)', - short_code, long_url +async def create_tables(conn): + await conn.execute(''' + CREATE TABLE IF NOT EXISTS short_urls ( + short_code VARCHAR(3) PRIMARY KEY, + long_url TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - - short_url = f"https://sij.ai/{short_code}" - return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) + ''') + await conn.execute(''' + CREATE TABLE IF NOT EXISTS click_logs ( + id SERIAL PRIMARY KEY, + short_code VARCHAR(3) REFERENCES short_urls(short_code), + clicked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + ip_address TEXT, + user_agent TEXT + ) + ''') @serve.get("/{short_code}", response_class=RedirectResponse, status_code=301) async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)): @@ -520,16 +492,39 @@ async def redirect_short_url(request: Request, short_code: str = PathParam(..., short_code ) - if result: - return result['long_url'] - else: - raise HTTPException(status_code=404, detail="Short URL not found") + if result: + await conn.execute( + 'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)', + short_code, request.client.host, request.headers.get("user-agent") + ) + return result['long_url'] + else: + raise HTTPException(status_code=404, detail="Short URL not found") -async def create_short_urls_table(conn): - await conn.execute(''' - CREATE TABLE IF NOT EXISTS short_urls ( - short_code VARCHAR(3) PRIMARY KEY, - long_url TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +@serve.get("/analytics/{short_code}") +async def get_analytics(short_code: str): + async with DB.get_connection() as conn: + url_info = await conn.fetchrow( + 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', + short_code ) - ''') \ No newline at end of file + if not url_info: + raise HTTPException(status_code=404, detail="Short URL not found") + + click_count = await conn.fetchval( + 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', + short_code + ) + + clicks = await conn.fetch( + 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', + short_code + ) + + return { + "short_code": short_code, + "long_url": url_info['long_url'], + "created_at": url_info['created_at'], + "total_clicks": click_count, + "recent_clicks": [dict(click) for click in clicks] + }