#!/usr/bin/env python3

import subprocess
import requests
import argparse
import json
import random
import datetime
import os

# Plain-text log that write_log() appends to and get_connection_history() reads.
LOG_FILE = "/var/log/vpn_rotation.txt"

# Countries eligible for the random pick made by the "shh" action.
PRIVACY_FRIENDLY_COUNTRIES = [
    "Finland",
    "Germany",
    "Iceland",
    "Netherlands",
    "Norway",
    "Sweden",
    "Switzerland",
]

# Extra flags appended to every `tailscale set --exit-node=...` invocation.
# Kept as a list because it is concatenated onto list-based command lines.
TAILSCALE_ARGS = [
    "--exit-node-allow-lan-access",
    "--accept-dns",
    "--accept-routes",
]

def get_mullvad_info(timeout=10):
    """
    Fetch JSON info from Mullvad's 'am.i.mullvad.net/json' endpoint.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON body. Callers in this file read its 'ip' and
        'country' keys.

    Raises:
        Exception: if the endpoint answers with a non-200 status code.
        Errors raised by requests itself (connection failure, timeout)
        propagate unchanged.
    """
    # Without an explicit timeout, requests waits indefinitely, which would
    # hang the whole script if the network is down.
    response = requests.get('https://am.i.mullvad.net/json', timeout=timeout)
    if response.status_code != 200:
        raise Exception("Could not fetch Mullvad info.")
    return response.json()

def get_current_exit_node():
    """
    Return the DNSName (e.g. 'de-ber-wg-001.mullvad.ts.net') of whichever
    peer is currently acting as the exit node, with any trailing dot removed.
    Returns None when no peer is flagged as the exit node.

    Raises:
        Exception: if 'tailscale status --json' exits with a non-zero code.
    """
    result = subprocess.run(['tailscale', 'status', '--json'],
                            capture_output=True, text=True)
    if result.returncode != 0:
        raise Exception("Failed to get Tailscale status")

    status = json.loads(result.stdout)

    # 'Peer' is a dict with keys like "nodekey:fe8efdbab7c2...". The key can
    # be present with a JSON null value, in which case dict.get's default is
    # not used, so fall back to an empty dict explicitly.
    peers = status.get('Peer') or {}
    for peer_data in peers.values():
        # The peer acting as exit node carries "ExitNode": true
        if peer_data.get('ExitNode') is True:
            # Tailscale might return 'de-ber-wg-001.mullvad.ts.net.' with a
            # trailing dot; strip it so names compare cleanly.
            dns_name = peer_data.get('DNSName') or ''
            return dns_name.rstrip('.')

    # No peer with ExitNode = true means there is no exit node
    return None

def _parse_exit_node_listing(output):
    """Turn the text printed by 'tailscale exit-node list' into {node_name: country}."""
    exit_nodes = {}
    for line in output.splitlines():
        stripped = line.strip()
        # Blank lines and '#' hint lines carry no node information.
        if not stripped or stripped.startswith('#'):
            continue

        # Prefer column-aware parsing: cells separated by two or more spaces.
        # This keeps multi-word values such as "United States" in one field.
        fields = [cell.strip() for cell in stripped.split('  ') if cell.strip()]
        if len(fields) <= 3:
            # Not column-aligned: fall back to plain whitespace tokens.
            fields = stripped.split()

        # Basic sanity check for lines that actually contain node info
        if len(fields) <= 3:
            continue

        # fields[0] is a leading marker/IP column, fields[1] the FQDN
        # (like "de-dus-wg-001.mullvad.ts.net"), fields[2] the country.
        node_name, node_country = fields[1], fields[2]

        # Skip a table header row so it is not mistaken for a node.
        if node_name.upper() == 'HOSTNAME' and node_country.upper() == 'COUNTRY':
            continue

        exit_nodes[node_name] = node_country

    return exit_nodes

def list_exit_nodes():
    """
    Return a dict {node_name: country} of all available Tailscale exit nodes
    based on 'tailscale exit-node list'.

    In every data line the second field is taken as the node name and the
    third as the country. Lines with three or fewer fields, blank lines,
    '#' comment lines and a HOSTNAME/COUNTRY header row are ignored.

    NOTE(review): the parser assumes either a column-aligned table (cells
    separated by at least two spaces, e.g. IP / HOSTNAME / COUNTRY / CITY /
    STATUS) or single-space separated tokens such as
       * de-dus-wg-001.mullvad.ts.net Germany linux ...
    Confirm against the output of the installed Tailscale version.

    Raises:
        Exception: if the tailscale command exits with a non-zero code.
    """
    result = subprocess.run(['tailscale', 'exit-node', 'list'], capture_output=True, text=True)
    if result.returncode != 0:
        raise Exception("Failed to list Tailscale exit nodes")

    return _parse_exit_node_listing(result.stdout)

def write_log(
    old_node=None, new_node=None,
    old_ip=None, new_ip=None,
    old_country=None, new_country=None
):
    """
    Appends a line to the log file reflecting a connection change.
    Example:
        2025.01.17 01:11:33 UTC · disconnected from de-dus-wg-001.mullvad.ts.net (Germany)
         · connected to at-vie-wg-001.mullvad.ts.net (Austria)
         · changed IP from 65.21.99.202 to 185.213.155.74
    (written as a single line; the clauses are joined with " · ").

    Args:
        old_node, old_country: previous exit node and its country. The
            "disconnected from" clause is written only when both are truthy.
        new_node, new_country: new exit node and its country. The
            "connected to" clause is written only when both are truthy.
        old_ip, new_ip: public IPs before/after. The "changed IP" clause is
            written only when both are truthy and they differ.

    If no old_node is specified, it indicates a fresh start (no disconnection).
    If no new_node is specified, it indicates a stop (only disconnection).
    The timestamp is always written, even when no clause applies.
    """
    # datetime.utcnow() is deprecated since Python 3.12; an aware "now" in UTC
    # formats to exactly the same string.
    utc_time = datetime.datetime.now(datetime.timezone.utc).strftime('%Y.%m.%d %H:%M:%S UTC')
    log_parts = [utc_time]

    # If old_node was present, mention disconnect
    if old_node and old_country:
        log_parts.append(f"disconnected from {old_node} ({old_country})")

    # If new_node is present, mention connect
    if new_node and new_country:
        log_parts.append(f"connected to {new_node} ({new_country})")

    # If IPs changed
    if old_ip and new_ip and old_ip != new_ip:
        log_parts.append(f"changed IP from {old_ip} to {new_ip}")

    line = " · ".join(log_parts)

    # Append to file
    with open(LOG_FILE, 'a') as f:
        f.write(line + "\n")

def _parse_node_clause(clause, prefix):
    """Return (node, country) from '<prefix><node> (<country>)', or (None, None) if malformed."""
    if not clause.startswith(prefix):
        return None, None
    node, sep, rest = clause[len(prefix):].partition(" (")
    if not sep or not rest.endswith(")"):
        return None, None
    # Drop only the closing parenthesis that terminates the clause.
    return node, rest[:-1]

def get_connection_history():
    """
    Returns an in-memory list of parsed log lines, in file order.
    Each item looks like:
        {
            'timestamp': datetime_object,
            'disconnected_node': '...',
            'disconnected_country': '...',
            'connected_node': '...',
            'connected_country': '...',
            'old_ip': '...',
            'new_ip': '...'
        }
    Fields whose clause is absent or malformed are None. Lines whose leading
    timestamp does not parse are skipped. Returns an empty list when the log
    file does not exist. Timestamps are naive datetimes parsed from the
    '%Y.%m.%d %H:%M:%S UTC' prefix.
    """
    entries = []
    if not os.path.isfile(LOG_FILE):
        return entries

    with open(LOG_FILE, 'r') as f:
        lines = f.readlines()

    for line in lines:
        # Example line:
        # 2025.01.17 01:11:33 UTC · disconnected from de-dus-wg-001.mullvad.ts.net (Germany) · connected to ...
        # Each line can carry any combination of the three clauses.
        parts = line.strip().split(" · ")

        # parts[0] => '2025.01.17 01:11:33 UTC'
        try:
            dt = datetime.datetime.strptime(parts[0], '%Y.%m.%d %H:%M:%S UTC')
        except ValueError:
            continue  # Blank or foreign line: skip it.

        entry = {
            'timestamp': dt,
            'disconnected_node': None,
            'disconnected_country': None,
            'connected_node': None,
            'connected_country': None,
            'old_ip': None,
            'new_ip': None
        }

        for p in parts[1:]:
            p = p.strip()
            if p.startswith("disconnected from"):
                # e.g. "disconnected from de-dus-wg-001.mullvad.ts.net (Germany)"
                node, country = _parse_node_clause(p, "disconnected from ")
                if node is not None:
                    entry['disconnected_node'] = node
                    entry['disconnected_country'] = country
            elif p.startswith("connected to"):
                # e.g. "connected to at-vie-wg-001.mullvad.ts.net (Austria)"
                node, country = _parse_node_clause(p, "connected to ")
                if node is not None:
                    entry['connected_node'] = node
                    entry['connected_country'] = country
            elif p.startswith("changed IP from"):
                # e.g. "changed IP from 65.21.99.202 to 185.213.155.74"
                # index:  0       1  2    3            4  5
                ip_parts = p.split()
                # Index 5 is read below, so at least six tokens are required;
                # a truncated clause is ignored instead of raising IndexError.
                if len(ip_parts) >= 6:
                    entry['old_ip'] = ip_parts[3]
                    entry['new_ip'] = ip_parts[5]

        entries.append(entry)

    return entries

def get_last_connection_entry():
    """
    Return the most recent log entry whose 'connected_node' is set,
    i.e. the latest recorded connection, or None if the log has none.
    """
    # The history is in file order, so walk it newest-first.
    newest_first = reversed(get_connection_history())
    return next(
        (item for item in newest_first if item['connected_node']),
        None,
    )

def set_exit_node(exit_node):
    """
    Point Tailscale at 'exit_node' via 'tailscale set --exit-node=...'.

    Mullvad's info endpoint is queried before and after the switch so the
    caller can log the IP/country transition. The command runs with
    check=True, so a failing tailscale call raises CalledProcessError.

    Returns (old_ip, new_ip, old_node, new_node, old_country, new_country),
    where old_node may be None and new_node is the 'exit_node' argument.
    """
    # Snapshot of the connection before switching
    before = get_mullvad_info()
    previous_ip = before.get('ip')
    previous_country = before.get('country')
    previous_node = get_current_exit_node()  # None when no exit node is set

    subprocess.run(
        ['tailscale', 'set', f'--exit-node={exit_node}', *TAILSCALE_ARGS],
        check=True,
    )

    # Snapshot of the connection after switching
    after = get_mullvad_info()

    return (
        previous_ip,
        after.get('ip'),
        previous_node,
        exit_node,
        previous_country,
        after.get('country'),
    )

def unset_exit_node():
    """
    Clear the Tailscale exit node and log the disconnection.

    Runs 'tailscale set --exit-node=' (empty value) with the shared extra
    flags, records the IP/country seen before and after, writes a log line
    and prints a confirmation.
    """
    # State before clearing, needed for the log line.
    before = get_mullvad_info()
    previous_ip = before.get('ip')
    previous_country = before.get('country')
    previous_node = get_current_exit_node()

    subprocess.run(['tailscale', 'set', '--exit-node=', *TAILSCALE_ARGS], check=True)

    # State after clearing; there is no new node, so None is logged for it.
    after = get_mullvad_info()
    write_log(
        previous_node,
        None,
        previous_ip,
        after.get('ip'),
        previous_country,
        after.get('country'),
    )
    print("Exit node unset successfully!")

def start_exit_node():
    """
    Connect to the node proposed by 'tailscale exit-node suggest' unless an
    exit node is already set, in which case only report the current one.

    Raises Exception when the suggest command fails or its output contains
    no usable 'Suggested exit node' line.
    """
    active_node = get_current_exit_node()
    if active_node:
        print(f"Already connected to exit node: {active_node}")
        return

    # Ask Tailscale which exit node it recommends
    proc = subprocess.run(['tailscale', 'exit-node', 'suggest'], capture_output=True, text=True)
    if proc.returncode != 0:
        raise Exception("Failed to run 'tailscale exit-node suggest'")

    # Only the first line mentioning the suggestion is considered.
    hit = next(
        (text for text in proc.stdout.splitlines() if 'Suggested exit node' in text),
        None,
    )
    suggested = hit.split(': ')[1].strip() if hit is not None else ''

    if not suggested:
        raise Exception("No suggested exit node found.")

    change = set_exit_node(suggested)
    old_ip, new_ip, old_node, new_node, old_country, new_country = change

    # Record the switch
    write_log(old_node, new_node, old_ip, new_ip, old_country, new_country)
    print(f"Exit node set successfully to {new_node}")

def set_random_privacy_friendly_exit_node():
    """
    Switch to a randomly chosen exit node located in one of the
    PRIVACY_FRIENDLY_COUNTRIES, then log the change.

    Raises Exception if no listed exit node is in such a country.
    """
    # list_exit_nodes() maps node_name -> country; keep matching names only
    candidates = []
    for name, country in list_exit_nodes().items():
        if country in PRIVACY_FRIENDLY_COUNTRIES:
            candidates.append(name)

    if not candidates:
        raise Exception("No privacy-friendly exit nodes available")

    exit_node = random.choice(candidates)
    change = set_exit_node(exit_node)
    old_ip, new_ip, old_node, new_node, old_country, new_country = change

    # Record the switch
    write_log(old_node, new_node, old_ip, new_ip, old_country, new_country)
    print(f"Selected random privacy-friendly exit node: {exit_node}")
    print("Exit node set successfully!")

def set_random_exit_node_in_country(country_input):
    """
    Switch to a randomly chosen exit node whose country equals
    'country_input' (compared case-insensitively, surrounding whitespace
    ignored), then log the change.

    Raises Exception if no listed exit node is in that country.
    """
    wanted = country_input.strip().lower()

    # Collect every node whose country matches the request
    candidates = []
    for name, country in list_exit_nodes().items():
        if country.lower() == wanted:
            candidates.append(name)

    if not candidates:
        raise Exception(f"No exit nodes found in {country_input}.")

    exit_node = random.choice(candidates)
    change = set_exit_node(exit_node)
    old_ip, new_ip, old_node, new_node, old_country, new_country = change

    # Record the switch
    write_log(old_node, new_node, old_ip, new_ip, old_country, new_country)
    print(f"Selected random exit node in {country_input.title()}: {exit_node}")
    print("Exit node set successfully!")

def get_status():
    """
    Print current connection status:
    - Whether connected or not
    - Current exit node and IP
    - Country reported by Mullvad for the current connection
    - How long it has been connected to that exit node (based on the most
      recent log entry that connected to it)

    Prints to stdout and returns None.
    """
    current_node = get_current_exit_node()
    if not current_node:
        print("No exit node is currently set.")
        return

    # Current IP & country
    info = get_mullvad_info()
    current_ip = info.get('ip')
    current_country = info.get('country')

    # Find the last time we connected to this node in the log.
    # get_current_exit_node() strips the trailing dot from DNS names, while
    # names stored in the log are written as passed to set_exit_node() and
    # may still carry one; compare with the dot removed on both sides.
    wanted = current_node.rstrip('.')
    connected_since = None
    for entry in reversed(get_connection_history()):
        logged_node = entry['connected_node']
        if logged_node and logged_node.rstrip('.') == wanted:
            connected_since = entry['timestamp']
            break

    print(f"Currently connected to: {current_node} ({current_country})")
    print(f"IP: {current_ip}")

    if connected_since is None:
        # Never found in the log: presumably connected outside this script
        print("Connected for: <unknown>, no log entry found.")
        return

    # Log timestamps are naive UTC, so compare against a naive UTC "now".
    # (datetime.utcnow() is deprecated since Python 3.12.)
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    # Clamp at zero so a log timestamp slightly in the future (clock
    # adjustment) does not produce a nonsensical duration.
    total_seconds = max(0, int((now_utc - connected_since).total_seconds()))

    # Show something like "1h 12m" or "2d 3h", omitting zero components
    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes = remainder // 60

    duration_parts = []
    if days > 0:
        duration_parts.append(f"{days}d")
    if hours > 0:
        duration_parts.append(f"{hours}h")
    if minutes > 0:
        duration_parts.append(f"{minutes}m")
    if not duration_parts:
        duration_parts.append("0m")  # means less than 1 minute

    print(f"Connected for: {' '.join(duration_parts)}")

if __name__ == "__main__":
    # Command-line entry point: dispatch the requested action.
    parser = argparse.ArgumentParser(description='Manage VPN exit nodes.')
    parser.add_argument(
        'action',
        choices=['start', 'stop', 'new', 'shh', 'to', 'status'],
        help='Action to perform: start, stop, new, shh, to <country>, or status'
    )
    parser.add_argument(
        'country',
        nargs='?',
        default=None,
        help='Country name (used only with "to" mode).'
    )

    args = parser.parse_args()

    # "vpn to" without a country is a usage error: report it through argparse
    # (usage text + exit status 2) rather than an uncaught-exception traceback.
    if args.action == 'to' and not args.country:
        parser.error("You must specify a country. e.g. vpn to sweden")

    if args.action == 'start':
        start_exit_node()
    elif args.action == 'stop':
        unset_exit_node()
    elif args.action == 'new':
        # Switch to whatever node "tailscale exit-node suggest" proposes,
        # even if an exit node is already set.
        result = subprocess.run(['tailscale', 'exit-node', 'suggest'], capture_output=True, text=True)
        if result.returncode != 0:
            raise Exception("Failed to run 'tailscale exit-node suggest'")

        exit_node = ''
        for line in result.stdout.splitlines():
            if 'Suggested exit node' in line:
                # partition() yields '' when ': ' is missing, so unexpected
                # output ends in the explicit error below, not an IndexError.
                exit_node = line.partition(': ')[2].strip()
                break

        if not exit_node:
            raise Exception("No suggested exit node found.")

        (old_ip, new_ip,
         old_node, new_node,
         old_country, new_country) = set_exit_node(exit_node)
        write_log(old_node, new_node, old_ip, new_ip, old_country, new_country)
        print(f"Exit node set to suggested node: {new_node}")

    elif args.action == 'shh':
        # Random privacy-friendly
        set_random_privacy_friendly_exit_node()

    elif args.action == 'to':
        # "vpn to sweden" => pick a random node in Sweden
        set_random_exit_node_in_country(args.country)

    elif args.action == 'status':
        get_status()