Auto-update: Wed Jul 31 13:41:15 PDT 2024
This commit is contained in:
parent
10eb581ad4
commit
011673893d
6 changed files with 276 additions and 128 deletions
|
@ -28,6 +28,7 @@ MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count
|
|||
IMG = Configuration.load('img', 'secrets')
|
||||
News = Configuration.load('news', 'secrets')
|
||||
Scrape = Configuration.load('scrape', 'secrets', Dir)
|
||||
Serve = Configuration.load('serve')
|
||||
|
||||
# Directories & general paths
|
||||
ROUTER_DIR = BASE_DIR / "routers"
|
||||
|
|
|
@ -56,11 +56,11 @@ class Configuration(BaseModel):
|
|||
with yaml_path.open('r') as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
|
||||
info(f"Loaded configuration data from {yaml_path}")
|
||||
debug(f"Loaded configuration data from {yaml_path}")
|
||||
if secrets_path:
|
||||
with secrets_path.open('r') as file:
|
||||
secrets_data = yaml.safe_load(file)
|
||||
info(f"Loaded secrets data from {secrets_path}")
|
||||
debug(f"Loaded secrets data from {secrets_path}")
|
||||
if isinstance(config_data, list):
|
||||
for item in config_data:
|
||||
if isinstance(item, dict):
|
||||
|
@ -184,7 +184,14 @@ class APIConfig(BaseModel):
|
|||
|
||||
SPECIAL_TABLES: ClassVar[List[str]] = ['spatial_ref_sys']
|
||||
|
||||
_db_pools: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
db_pools: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._db_pools = {}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
|
||||
|
@ -307,17 +314,17 @@ class APIConfig(BaseModel):
|
|||
|
||||
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
||||
|
||||
if pool_key not in self._db_pools:
|
||||
if pool_key not in self.db_pools:
|
||||
try:
|
||||
self._db_pools[pool_key] = await asyncpg.create_pool(
|
||||
self.db_pools[pool_key] = await asyncpg.create_pool(
|
||||
host=pool_entry['ts_ip'],
|
||||
port=pool_entry['db_port'],
|
||||
user=pool_entry['db_user'],
|
||||
password=pool_entry['db_pass'],
|
||||
database=pool_entry['db_name'],
|
||||
min_size=1,
|
||||
max_size=10, # adjust as needed
|
||||
timeout=5 # connection timeout in seconds
|
||||
max_size=10,
|
||||
timeout=5
|
||||
)
|
||||
except Exception as e:
|
||||
err(f"Failed to create connection pool for {pool_key}: {str(e)}")
|
||||
|
@ -325,27 +332,70 @@ class APIConfig(BaseModel):
|
|||
return
|
||||
|
||||
try:
|
||||
async with self._db_pools[pool_key].acquire() as conn:
|
||||
async with self.db_pools[pool_key].acquire() as conn:
|
||||
yield conn
|
||||
except asyncpg.exceptions.ConnectionDoesNotExistError:
|
||||
err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
|
||||
yield None
|
||||
except asyncpg.exceptions.ConnectionFailureError:
|
||||
err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
|
||||
yield None
|
||||
except Exception as e:
|
||||
err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
|
||||
err(f"Failed to acquire connection from pool for {pool_key}: {str(e)}")
|
||||
yield None
|
||||
|
||||
async def push_changes_to_one(self, pool_entry):
|
||||
try:
|
||||
async with self.get_connection() as local_conn:
|
||||
if local_conn is None:
|
||||
err(f"Failed to connect to local database. Skipping push to {pool_entry['ts_id']}")
|
||||
return
|
||||
|
||||
async with self.get_connection(pool_entry) as remote_conn:
|
||||
if remote_conn is None:
|
||||
err(f"Failed to connect to remote database {pool_entry['ts_id']}. Skipping push.")
|
||||
return
|
||||
|
||||
tables = await local_conn.fetch("""
|
||||
SELECT tablename FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
""")
|
||||
|
||||
for table in tables:
|
||||
table_name = table['tablename']
|
||||
try:
|
||||
if table_name in self.SPECIAL_TABLES:
|
||||
await self.sync_special_table(local_conn, remote_conn, table_name)
|
||||
else:
|
||||
primary_key = await self.ensure_sync_columns(remote_conn, table_name)
|
||||
last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
|
||||
|
||||
changes = await local_conn.fetch(f"""
|
||||
SELECT * FROM "{table_name}"
|
||||
WHERE version > $1 AND server_id = $2
|
||||
ORDER BY version ASC
|
||||
""", last_synced_version, os.environ.get('TS_ID'))
|
||||
|
||||
if changes:
|
||||
changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
|
||||
|
||||
if changes_count > 0:
|
||||
debug(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}")
|
||||
|
||||
except Exception as e:
|
||||
err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}")
|
||||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
info(f"Successfully pushed changes to {pool_entry['ts_id']}")
|
||||
|
||||
except Exception as e:
|
||||
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
|
||||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
async def close_db_pools(self):
|
||||
info("Closing database connection pools...")
|
||||
for pool_key, pool in self._db_pools.items():
|
||||
for pool_key, pool in self.db_pools.items():
|
||||
try:
|
||||
await pool.close()
|
||||
info(f"Closed pool for {pool_key}")
|
||||
debug(f"Closed pool for {pool_key}")
|
||||
except Exception as e:
|
||||
err(f"Error closing pool for {pool_key}: {str(e)}")
|
||||
self._db_pools.clear()
|
||||
self.db_pools.clear()
|
||||
info("All database connection pools closed.")
|
||||
|
||||
async def initialize_sync(self):
|
||||
|
@ -360,7 +410,7 @@ class APIConfig(BaseModel):
|
|||
if conn is None:
|
||||
continue # Skip this database if connection failed
|
||||
|
||||
info(f"Starting sync initialization for {pool_entry['ts_ip']}...")
|
||||
debug(f"Starting sync initialization for {pool_entry['ts_ip']}...")
|
||||
|
||||
# Check PostGIS installation
|
||||
postgis_installed = await self.check_postgis(conn)
|
||||
|
@ -376,15 +426,19 @@ class APIConfig(BaseModel):
|
|||
table_name = table['tablename']
|
||||
await self.ensure_sync_columns(conn, table_name)
|
||||
|
||||
info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.")
|
||||
debug(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.")
|
||||
|
||||
except Exception as e:
|
||||
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
|
||||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def ensure_sync_columns(self, conn, table_name):
|
||||
if conn is None:
|
||||
debug(f"Skipping offline server...")
|
||||
return None
|
||||
|
||||
if table_name in self.SPECIAL_TABLES:
|
||||
info(f"Skipping sync columns for special table: {table_name}")
|
||||
debug(f"Skipping sync columns for special table: {table_name}")
|
||||
return None
|
||||
|
||||
try:
|
||||
|
@ -439,7 +493,7 @@ class APIConfig(BaseModel):
|
|||
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
|
||||
""")
|
||||
|
||||
info(f"Successfully ensured sync columns and trigger for table {table_name}")
|
||||
debug(f"Successfully ensured sync columns and trigger for table {table_name}")
|
||||
return primary_key
|
||||
|
||||
except Exception as e:
|
||||
|
@ -447,7 +501,8 @@ class APIConfig(BaseModel):
|
|||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def apply_batch_changes(self, conn, table_name, changes, primary_key):
|
||||
if not changes:
|
||||
if conn is None or not changes:
|
||||
debug(f"Skipping apply_batch_changes because conn is none or there are no changes.")
|
||||
return 0
|
||||
|
||||
try:
|
||||
|
@ -473,12 +528,12 @@ class APIConfig(BaseModel):
|
|||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
|
||||
debug(f"Generated insert query for {table_name}: {insert_query}")
|
||||
# debug(f"Generated insert query for {table_name}: {insert_query}")
|
||||
|
||||
affected_rows = 0
|
||||
async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
|
||||
values = [change[col] for col in columns]
|
||||
debug(f"Executing query for {table_name} with values: {values}")
|
||||
# debug(f"Executing query for {table_name} with values: {values}")
|
||||
result = await conn.execute(insert_query, *values)
|
||||
affected_rows += int(result.split()[-1])
|
||||
|
||||
|
@ -491,7 +546,7 @@ class APIConfig(BaseModel):
|
|||
|
||||
async def pull_changes(self, source_pool_entry, batch_size=10000):
|
||||
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
||||
info("Skipping self-sync")
|
||||
debug("Skipping self-sync")
|
||||
return 0
|
||||
|
||||
total_changes = 0
|
||||
|
@ -533,7 +588,7 @@ class APIConfig(BaseModel):
|
|||
if changes_count > 0:
|
||||
info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
|
||||
else:
|
||||
info(f"No changes to sync for {table_name}")
|
||||
debug(f"No changes to sync for {table_name}")
|
||||
|
||||
except Exception as e:
|
||||
err(f"Error syncing table {table_name}: {str(e)}")
|
||||
|
@ -572,47 +627,15 @@ class APIConfig(BaseModel):
|
|||
except Exception as e:
|
||||
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
|
||||
|
||||
async def push_changes_to_one(self, pool_entry):
|
||||
try:
|
||||
async with self.get_connection() as local_conn:
|
||||
async with self.get_connection(pool_entry) as remote_conn:
|
||||
tables = await local_conn.fetch("""
|
||||
SELECT tablename FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
""")
|
||||
|
||||
for table in tables:
|
||||
table_name = table['tablename']
|
||||
try:
|
||||
if table_name in self.SPECIAL_TABLES:
|
||||
await self.sync_special_table(local_conn, remote_conn, table_name)
|
||||
else:
|
||||
primary_key = await self.ensure_sync_columns(remote_conn, table_name)
|
||||
last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
|
||||
|
||||
changes = await local_conn.fetch(f"""
|
||||
SELECT * FROM "{table_name}"
|
||||
WHERE version > $1 AND server_id = $2
|
||||
ORDER BY version ASC
|
||||
""", last_synced_version, os.environ.get('TS_ID'))
|
||||
|
||||
if changes:
|
||||
changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
|
||||
|
||||
if changes_count > 0:
|
||||
info(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}")
|
||||
|
||||
except Exception as e:
|
||||
err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}")
|
||||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
info(f"Successfully pushed changes to {pool_entry['ts_id']}")
|
||||
except Exception as e:
|
||||
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
|
||||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def get_last_synced_version(self, conn, table_name, server_id):
|
||||
if conn is None:
|
||||
debug(f"Skipping offline server...")
|
||||
return 0
|
||||
|
||||
if table_name in self.SPECIAL_TABLES:
|
||||
debug(f"Skipping get_last_synced_version becaue {table_name} is special.")
|
||||
return 0 # Special handling for tables without version column
|
||||
|
||||
return await conn.fetchval(f"""
|
||||
|
@ -622,10 +645,14 @@ class APIConfig(BaseModel):
|
|||
""", server_id)
|
||||
|
||||
async def check_postgis(self, conn):
|
||||
if conn is None:
|
||||
debug(f"Skipping offline server...")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await conn.fetchval("SELECT PostGIS_version();")
|
||||
if result:
|
||||
info(f"PostGIS version: {result}")
|
||||
debug(f"PostGIS version: {result}")
|
||||
return True
|
||||
else:
|
||||
warn("PostGIS is not installed or not working properly")
|
||||
|
@ -669,7 +696,7 @@ class APIConfig(BaseModel):
|
|||
INSERT INTO spatial_ref_sys ({', '.join(f'"{col}"' for col in columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
debug(f"Inserting new entry for srid {srid}: {insert_query}")
|
||||
# debug(f"Inserting new entry for srid {srid}: {insert_query}")
|
||||
await dest_conn.execute(insert_query, *source_entry.values())
|
||||
inserts += 1
|
||||
elif source_entry != dest_dict[srid]:
|
||||
|
@ -682,7 +709,7 @@ class APIConfig(BaseModel):
|
|||
proj4text = $4::text
|
||||
WHERE srid = $5::integer
|
||||
"""
|
||||
debug(f"Updating entry for srid {srid}: {update_query}")
|
||||
# debug(f"Updating entry for srid {srid}: {update_query}")
|
||||
await dest_conn.execute(update_query,
|
||||
source_entry['auth_name'],
|
||||
source_entry['auth_srid'],
|
||||
|
@ -705,6 +732,10 @@ class APIConfig(BaseModel):
|
|||
max_version = -1
|
||||
local_ts_id = os.environ.get('TS_ID')
|
||||
online_hosts = await self.get_online_hosts()
|
||||
num_online_hosts = len(online_hosts)
|
||||
if num_online_hosts > 0:
|
||||
online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id]
|
||||
crit(f"Online hosts: {', '.join(online_ts_ids)}")
|
||||
|
||||
for pool_entry in online_hosts:
|
||||
if pool_entry['ts_id'] == local_ts_id:
|
||||
|
@ -737,7 +768,7 @@ class APIConfig(BaseModel):
|
|||
max_version = version
|
||||
most_recent_source = pool_entry
|
||||
else:
|
||||
info(f"No data in table {table_name} for {pool_entry['ts_id']}")
|
||||
debug(f"No data in table {table_name} for {pool_entry['ts_id']}")
|
||||
except asyncpg.exceptions.UndefinedColumnError:
|
||||
warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
|
||||
except Exception as e:
|
||||
|
@ -750,10 +781,9 @@ class APIConfig(BaseModel):
|
|||
err(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
return most_recent_source
|
||||
|
||||
|
||||
|
||||
|
||||
else:
|
||||
warn(f"No other online hosts for sync")
|
||||
return None
|
||||
|
||||
|
||||
class Location(BaseModel):
|
||||
|
|
24
sijapi/config/cloudflare.yaml-example
Normal file
24
sijapi/config/cloudflare.yaml-example
Normal file
|
@ -0,0 +1,24 @@
|
|||
base_url: 'https://cloudflare.com'
|
||||
token: '{{ SECRET.CF_TOKEN }}'
|
||||
cf_ip: '65.1.1.1' # replace me
|
||||
ai: # tld, e.g. .ai
|
||||
sij: # domain, e.g. sij.ai
|
||||
_zone: c00a9e0ff540308232eb5762621d5b1 # zone id
|
||||
_www: 8a26b17923ac3a8f21b6127cdb3d7459 # dns id for domain, e.g. www.sij.ai
|
||||
api: 8a336ee8a5b13e112d6d4ae77c149bd6 # dns id for subdomain, e.g. api.sij.ai
|
||||
txt: b4b0bd48ac4272b1c48eb1624072adb2 # dns id for subdomaikn, e.g. txt.sij.ai
|
||||
git: cd524c00b6daf824c933a294cb52eae2 # dns id for subdomain, e.g. git.sij.ai
|
||||
law: # tld, e.g. .law
|
||||
sij: # domain, e.g. sij.law
|
||||
_zone: 5b68d9cd99b896e26c232f03cda89d66 # zone id
|
||||
_www: ba9afd99deeb0407ea1b74ba88eb5564 # dns id for domain, e.g. www.sij.law
|
||||
map: 4de8fe05bb0e722ee2c78b2ddf553c82 # dns id for subdomain, e.g. map.sij.law
|
||||
imap: 384acd03c139ffaed37f4e70c627e7d1 # dns id for subdomain, e.g. imap.sij.law
|
||||
smtp: 0677e42ea9b589d67d1da21aa00455e0 # dns id for subdomain, e.g. smtp.sij.law
|
||||
esq: # tld, e.g. .esq
|
||||
env: # domain, e.g. env.esq
|
||||
_zone: faf889fd7c227c2e61875b2e70b5c6fe # zone id
|
||||
_www: b9b636ce9bd4812a6564f572f0f373ee # dns id for domain, e.g. www.env.esq
|
||||
dt: afbc205e829cfb8d3f79dab187c06f99 # dns id for subdomain, e.g. dt.env.esq
|
||||
rss: f043d5cf485f4e53f9cbcb85fed2c861 # dns id for subdomain, e.g. rss.env.esq
|
||||
s3: a5fa431a4be8f50af2c118aed353b0ec # dns id for subdomain, e.g. s3.env.esq
|
37
sijapi/config/scrape.yaml-example
Normal file
37
sijapi/config/scrape.yaml-example
Normal file
|
@ -0,0 +1,37 @@
|
|||
- name: "CalFire_THP"
|
||||
url: "https://caltreesplans.resources.ca.gov/Caltrees/Report/ShowReport.aspx?module=TH_Document&reportID=492&reportType=LINK_REPORT_LIST"
|
||||
output_file: "{{ Dir.DATA }}/calfire_thp_data.json"
|
||||
content:
|
||||
type: "pdf"
|
||||
selector: null
|
||||
js_render: false
|
||||
processing:
|
||||
- name: "split_entries"
|
||||
type: "regex_split"
|
||||
pattern: '(\d+-\d+-\d+-\w+)'
|
||||
- name: "filter_entries"
|
||||
type: "keyword_filter"
|
||||
keywords: ["Sierra Pacific", "SPI", "Land & Timber"]
|
||||
- name: "extract_data"
|
||||
type: "regex_extract"
|
||||
extractions:
|
||||
- name: "Harvest Document"
|
||||
pattern: '(\d+-\d+-\d+-\w+)'
|
||||
- name: "Land Owner"
|
||||
pattern: '((?:SIERRA PACIFIC|SPI|.*?LAND & TIMBER).*?)(?=\d+-\d+-\d+-\w+|\Z)'
|
||||
flags: ["DOTALL", "IGNORECASE"]
|
||||
- name: "Location"
|
||||
pattern: '((?:MDBM|HBM):.*?)(?=(?:SIERRA PACIFIC|SPI|.*?LAND & TIMBER)|\Z)'
|
||||
flags: ["DOTALL"]
|
||||
- name: "Total Acres"
|
||||
pattern: '(\d+\.\d+)\s+acres'
|
||||
- name: "Watershed"
|
||||
pattern: 'Watershed:\s+(.+)'
|
||||
post_processing:
|
||||
- name: "extract_plss_coordinates"
|
||||
type: "regex_extract"
|
||||
field: "Location"
|
||||
pattern: '(\w+): T(\d+)([NSEW]) R(\d+)([NSEW]) S(\d+)'
|
||||
output_field: "PLSS Coordinates"
|
||||
all_matches: true
|
||||
format: "{0}: T{1}{2} R{3}{4} S{5}"
|
5
sijapi/config/serve.yaml-example
Normal file
5
sijapi/config/serve.yaml-example
Normal file
|
@ -0,0 +1,5 @@
|
|||
forwarding_rules:
|
||||
- source: "test.domain.com:80"
|
||||
destination: "100.64.64.14:8080"
|
||||
- source: "100.64.64.20:1024"
|
||||
destination: "127.0.0.1:1025"
|
|
@ -31,7 +31,7 @@ from selenium.webdriver.common.by import By
|
|||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from sijapi import (
|
||||
L, API, LOGS_DIR, TS_ID, CASETABLE_PATH, COURTLISTENER_DOCKETS_URL, COURTLISTENER_API_KEY,
|
||||
L, API, Serve, LOGS_DIR, TS_ID, CASETABLE_PATH, COURTLISTENER_DOCKETS_URL, COURTLISTENER_API_KEY,
|
||||
COURTLISTENER_BASE_URL, COURTLISTENER_DOCKETS_DIR, COURTLISTENER_SEARCH_DIR, ALERTS_DIR,
|
||||
MAC_UN, MAC_PW, MAC_ID, TS_TAILNET, IMG_DIR, PUBLIC_KEY, OBSIDIAN_VAULT_DIR
|
||||
)
|
||||
|
@ -51,7 +51,7 @@ templates = Jinja2Templates(directory=Path(__file__).parent.parent / "sites")
|
|||
|
||||
@serve.get("/pgp")
|
||||
async def get_pgp():
|
||||
return Response(PUBLIC_KEY, media_type="text/plain")
|
||||
return Response(Serve.PGP, media_type="text/plain")
|
||||
|
||||
@serve.get("/img/{image_name}")
|
||||
def serve_image(image_name: str):
|
||||
|
@ -119,17 +119,6 @@ async def hook_alert(request: Request):
|
|||
|
||||
return await notify(alert)
|
||||
|
||||
@serve.post("/alert/cd")
|
||||
async def hook_changedetection(webhook_data: dict):
|
||||
body = webhook_data.get("body", {})
|
||||
message = body.get("message", "")
|
||||
|
||||
if message and any(word in message.split() for word in ["SPI", "sierra", "pacific"]):
|
||||
filename = ALERTS_DIR / f"alert_{int(time.time())}.json"
|
||||
filename.write_text(json.dumps(webhook_data, indent=4))
|
||||
notify(message)
|
||||
|
||||
return {"status": "received"}
|
||||
|
||||
async def notify(alert: str):
|
||||
fail = True
|
||||
|
@ -528,3 +517,65 @@ async def get_analytics(short_code: str):
|
|||
"total_clicks": click_count,
|
||||
"recent_clicks": [dict(click) for click in clicks]
|
||||
}
|
||||
|
||||
|
||||
|
||||
async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str):
|
||||
dest_host, dest_port = destination.split(':')
|
||||
dest_port = int(dest_port)
|
||||
|
||||
try:
|
||||
dest_reader, dest_writer = await asyncio.open_connection(dest_host, dest_port)
|
||||
except Exception as e:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
return
|
||||
|
||||
async def forward(src, dst):
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(8192)
|
||||
if not data:
|
||||
break
|
||||
dst.write(data)
|
||||
await dst.drain()
|
||||
except Exception as e:
|
||||
pass
|
||||
finally:
|
||||
dst.close()
|
||||
await dst.wait_closed()
|
||||
|
||||
await asyncio.gather(
|
||||
forward(reader, dest_writer),
|
||||
forward(dest_reader, writer)
|
||||
)
|
||||
|
||||
async def start_server(source: str, destination: str):
|
||||
host, port = source.split(':')
|
||||
port = int(port)
|
||||
|
||||
server = await asyncio.start_server(
|
||||
lambda r, w: forward_traffic(r, w, destination),
|
||||
host,
|
||||
port
|
||||
)
|
||||
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
async def start_port_forwarding():
|
||||
if hasattr(Serve, 'forwarding_rules'):
|
||||
for rule in Serve.forwarding_rules:
|
||||
asyncio.create_task(start_server(rule.source, rule.destination))
|
||||
else:
|
||||
warn("No forwarding rules found in the configuration.")
|
||||
|
||||
@serve.get("/forward_status")
|
||||
async def get_forward_status():
|
||||
if hasattr(Serve, 'forwarding_rules'):
|
||||
return {"status": "active", "rules": Serve.forwarding_rules}
|
||||
else:
|
||||
return {"status": "inactive", "message": "No forwarding rules configured"}
|
||||
|
||||
# Add this to the end of your serve.py file
|
||||
asyncio.create_task(start_port_forwarding())
|
Loading…
Reference in a new issue