diff --git a/sijapi/helpers/start.py b/sijapi/helpers/start.py index 0f218fc..c41c54f 100644 --- a/sijapi/helpers/start.py +++ b/sijapi/helpers/start.py @@ -6,6 +6,8 @@ from pathlib import Path import logging import subprocess import os +import argparse +import sys logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -31,8 +33,7 @@ def check_server(ip, port, ts_id): address = f"http://{ip}:{port}/id" try: response = requests.get(address, timeout=5) - response_text = response.text.strip().strip('"') - + response_text = response.text.strip().strip('"') return response.status_code == 200 and response_text == ts_id except requests.RequestException as e: logging.error(f"Error checking server {ts_id}: {str(e)}") @@ -70,7 +71,7 @@ def start_local_server(server): def start_remote_server(server): ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - + try: ssh.connect( server['ts_ip'], @@ -79,7 +80,7 @@ def start_remote_server(server): password=server['ssh_pass'], timeout=10 ) - + status, output, error = execute_ssh_command(ssh, f"{server['tmux']} has-session -t sijapi 2>/dev/null && echo 'exists' || echo 'not exists'") if output == 'exists': logging.info(f"sijapi session already exists on {server['ts_id']}") @@ -87,24 +88,93 @@ def start_remote_server(server): command = f"{server['tmux']} new-session -d -s sijapi 'cd {server['path']} && {server['conda_env']}/bin/python -m sijapi'" status, output, error = execute_ssh_command(ssh, command) - + if status == 0: logging.info(f"Successfully started sijapi session on {server['ts_id']}") else: logging.error(f"Failed to start sijapi session on {server['ts_id']}. Error: {error}") - + except paramiko.SSHException as e: logging.error(f"Failed to connect to {server['ts_id']}: {str(e)}") finally: ssh.close() +def kill_local_server(): + try: + if is_local_tmux_session_running('sijapi'): + subprocess.run(['tmux', 'kill-session', '-t', 'sijapi'], check=True) + logging.info("Killed local sijapi tmux session.") + else: + logging.info("No local sijapi tmux session to kill.") + except subprocess.CalledProcessError as e: + logging.error(f"Failed to kill local sijapi tmux session. Error: {e}") + +def kill_remote_server(server): + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + ssh.connect( + server['ts_ip'], + port=server['ssh_port'], + username=server['ssh_user'], + password=server['ssh_pass'], + timeout=10 + ) + + command = f"{server['tmux']} kill-session -t sijapi" + status, output, error = execute_ssh_command(ssh, command) + + if status == 0: + logging.info(f"Successfully killed sijapi session on {server['ts_id']}") + else: + logging.error(f"Failed to kill sijapi session on {server['ts_id']}. Error: {error}") + + except paramiko.SSHException as e: + logging.error(f"Failed to connect to {server['ts_id']}: {str(e)}") + finally: + ssh.close() def main(): load_env() config = load_config() pool = config['POOL'] local_ts_id = os.environ.get('TS_ID') - + + parser = argparse.ArgumentParser(description='Manage sijapi servers') + parser.add_argument('--kill', action='store_true', help='Kill the local sijapi tmux session') + parser.add_argument('--restart', action='store_true', help='Restart the local sijapi tmux session') + parser.add_argument('--all', action='store_true', help='Apply the action to all servers') + + args = parser.parse_args() + + if args.kill: + if args.all: + for server in pool: + if server['ts_id'] == local_ts_id: + kill_local_server() + else: + kill_remote_server(server) + else: + kill_local_server() + sys.exit(0) + + if args.restart: + if args.all: + for server in pool: + if server['ts_id'] == local_ts_id: + kill_local_server() + start_local_server(server) + else: + kill_remote_server(server) + start_remote_server(server) + else: + kill_local_server() + local_server = next(server for server in pool if server['ts_id'] == local_ts_id) + start_local_server(local_server) + sys.exit(0) + + # If no specific arguments, run the original script for server in pool: if check_server(server['ts_ip'], server['app_port'], server['ts_id']): logging.info(f"{server['ts_id']} is running and responding correctly.") @@ -114,8 +184,9 @@ def main(): start_local_server(server) else: start_remote_server(server) - + time.sleep(1) if __name__ == "__main__": - main() \ No newline at end of file + main() +