#!/usr/bin/python3

import json
import os
import subprocess
import sys
import traceback

import requests

# Utility functions


def sh(cmd, debug=os.environ.get("GAN_DEBUG")):
    """Run a shell command and return its stripped standard output.

    Args:
        cmd: Command line run through the shell (``shell=True``), so pipes,
            redirections and ``$(...)`` substitutions are available.
        debug: When truthy, echo the command prefixed with ``+`` and each
            line of its output prefixed with ``>``. Defaults to the value of
            the ``GAN_DEBUG`` environment variable, read once when this
            function is defined.

    Returns:
        The command's stdout decoded to ``str`` with surrounding whitespace
        stripped (an empty string when it printed nothing).

    Raises:
        subprocess.CalledProcessError: The command exited with a non-zero
            status; the captured stdout/stderr are attached to the exception.
    """
    if debug:
        print(f"+ {cmd}")
    completed = subprocess.run(
        cmd, shell=True, capture_output=True, check=True
    )
    output = completed.stdout.decode().strip()
    if debug and output:
        # Keep one "> "-prefixed line per line of output.
        print("\n".join(f"> {line}" for line in output.splitlines()))
    return output


def api(url, timeout=30):
    """Fetch a resource from the Grid'5000 API and decode its JSON body.

    Args:
        url: Path relative to ``https://api.grid5000.fr/stable/``
            (e.g. ``"sites/<site>/status"``).
        timeout: Seconds ``requests`` waits to connect and for each read
            before raising ``requests.exceptions.Timeout``. Without a timeout,
            ``requests`` may wait indefinitely and the script would hang.

    Returns:
        The decoded JSON body.

    Raises:
        requests.exceptions.HTTPError: The API answered with a 4xx/5xx status.
        requests.exceptions.Timeout: No answer within *timeout* seconds.
    """
    base_url = "https://api.grid5000.fr/stable/"
    response = requests.get(base_url + url, timeout=timeout)
    response.raise_for_status()
    return response.json()


def get_node_site():
    """Return ``(node, site)`` parsed from this machine's ``hostname -f``.

    The first two dot-separated labels of the FQDN are the node and the
    site. Only the first two dash-separated parts of the node label are
    kept, which drops any KaVLAN suffix appended to the host name.

    Raises:
        RuntimeError: The FQDN has fewer than two dot-separated labels.
    """
    labels = sh("hostname -f").split(".")
    if len(labels) < 2:
        raise RuntimeError("Cannot retrieve node and site from 'hostname -f'")
    host_label, site = labels[0], labels[1]
    # Keep "<first>-<second>" only; later dash-separated parts are the
    # KaVLAN suffix.
    short_node = "-".join(host_label.split("-")[:2])
    return short_node, site


def get_node_kavlan(node, site):
    """Return the VLAN the node belongs to, according to the G5K API.

    Looks up ``<node>.<site>.grid5000.fr`` in the site's ``vlans/nodes``
    listing and returns the ``vlan`` field of the first matching entry.
    If no entry matches or the API returns an HTTP error, this prints the
    traceback and a message to stderr, then exits the script with status 1.
    """
    fqdn = f"{node}.{site}.grid5000.fr"
    try:
        listing = api(f"sites/{site}/vlans/nodes")["items"]
        vlans = [entry["vlan"] for entry in listing if entry["uid"] == fqdn]
        return vlans[0]
    except (IndexError, requests.exceptions.HTTPError):
        traceback.print_exc()
        print(
            f"Cannot retrieve VLAN information for node {fqdn}",
            file=sys.stderr,
        )
        sys.exit(1)


def get_node_user(node, site):
    """Return the uid of the user whose job is running on the given node.

    Reads the site's status from the Grid'5000 API. It returns the
    ``user_uid`` of the first reservation in state ``running`` for
    ``<node>.<site>.grid5000.fr``.

    Exits with status 1 in any of these cases:
    - the node is absent from the status listing;
    - the node has no running reservation;
    - the API returns an HTTP error.
    Before exiting, it prints the traceback and a message to stderr.
    """
    # Look the node up by the FQDN built from the arguments, so the function
    # honours its parameters instead of re-querying the local hostname.
    fqdn = f"{node}.{site}.grid5000.fr"
    try:
        status = api(f"sites/{site}/status")
        reservations = status["nodes"][fqdn]["reservations"]
        return [
            item["user_uid"]
            for item in reservations
            if item["state"] == "running"
        ][0]
    except (IndexError, KeyError, requests.exceptions.HTTPError):
        # KeyError: the node is missing from the status listing.
        # IndexError: the node has no running reservation.
        traceback.print_exc()
        print(
            f"Cannot retrieve user for node {fqdn}",
            file=sys.stderr,
        )
        sys.exit(1)


def get_user_pubkey(user):
    """Return the SSH public key(s) registered for *user* in the G5K API.

    The value is the ``ssh_public_key`` field of ``users/users/<user>``,
    with surrounding whitespace stripped. The script later writes it as the
    only content of the new local account's ``authorized_keys``. An absent
    or blank key would make that account unreachable once root and LDAP
    access are removed, so it is rejected here instead.

    On an HTTP error or a missing/blank key, this prints a message to stderr
    and exits with status 1. For HTTP errors it also prints the traceback.
    """
    try:
        profile = api(f"users/users/{user}")
    except requests.exceptions.HTTPError:
        traceback.print_exc()
        profile = {}
    # A missing field and a JSON null both end up as "" and are rejected.
    pubkey = (profile.get("ssh_public_key") or "").strip()
    if not pubkey:
        print(f"Cannot retrieve user {user} public key", file=sys.stderr)
        sys.exit(1)
    return pubkey


# Start of the script

print("Getting node and user information from Grid'5000 API")
node, site = get_node_site()
user = get_node_user(node, site)
pubkey = get_user_pubkey(user)

# Refuse to continue unless the node runs the environment and the default
# postinstall options this script was written for. Each entry pairs a grep
# command that must succeed with the message printed before exiting on
# failure.
print(
    "Checking deployed environment and postinstall to ensure that "
    "it is supported by the script"
)
for check_cmd, failure_msg in (
    (
        "grep -E '^debian11-[a-z0-9]+-big' /etc/grid5000/release",
        "A debian11-*-big environment must be used",
    ),
    (
        r"grep '\"g5k-postinstall --net debian --net hpc --fstab nfs "
        r"--restrict-user current --disk-aliases\"' /etc/grid5000/postinstall",
        "Default postinstall options must be used",
    ),
):
    try:
        sh(check_cmd)
    except subprocess.CalledProcessError:
        print(failure_msg, file=sys.stderr)
        sys.exit(1)

print("Checking if node is in a KaVLAN")
node_kavlan = get_node_kavlan(node, site)
# Only KaVLANs with an id of 4 or more (routed or global) are accepted.
# "DEFAULT" means the node is not in any KaVLAN.
if node_kavlan == "DEFAULT" or int(node_kavlan) < 4:
    print(
        f"Node {node}.{site}.grid5000.fr is not in a KaVLAN "
        "(routed or global)",
        file=sys.stderr,
    )
    sys.exit(1)

# Postfix: relay outgoing mail through the site's mail host and identify as
# this node's FQDN. The notification mail sent below goes through it.
print("Configuring postfix")
sh(f"postconf relayhost=mail.{site}.grid5000.fr")
sh(f"postconf myhostname={node}.{site}.grid5000.fr")

# Record that g5k-armor-node is being applied in two ways:
# - a mail to root;
# - a syslog entry (authpriv.info) sent over TCP port 514 to the host named
#   "syslog", naming the user who started the script.
print("Notifying g5k-armor-node is used")
sh(
    'echo "g5k-armor-node running on $(hostname)" | mail -s '
    '"g5k-armor-node running on $(hostname)" root'
)
sh(
    "logger -n syslog --tcp -P514 -p authpriv.info "
    f'"g5k-armor-node running on $(hostname), started by {user}"'
)

print("Upgrading system packages")
# Make apt/dpkg run without interactive prompts; the commands run by sh()
# inherit this environment.
os.environ["DEBIAN_FRONTEND"] = "noninteractive"
sh(
    "apt-get update && "
    "apt-get -y install needrestart && "
    "apt-get -y full-upgrade"
)

# Restart each service that needrestart lists on a NEEDRESTART-SVC line.
# A failed restart is not fatal by itself, but it must lead to the reboot
# request below: the status lines checked there do not cover services.
failed_restarts = []
for service in sh(
    "needrestart -b -r a | grep NEEDRESTART-SVC | awk '{print $2}'"
).split("\n"):
    if service:
        try:
            sh(f"systemctl restart {service}")
        except subprocess.CalledProcessError:
            failed_restarts.append(service)

# A reboot is needed if any service could not be restarted, or if any
# "NEEDRESTART-*STA:" status line reports a value other than "1".
restart_needed = bool(failed_restarts) or any(
    status != "1"
    for status in sh(
        "needrestart -b -r a | grep -e 'NEEDRESTART-.*STA:' | awk '{print $2}'"
    ).split("\n")
)
if restart_needed:
    if failed_restarts:
        print(
            "Could not restart: " + ", ".join(failed_restarts),
            file=sys.stderr,
        )
    print(
        "ERROR: A reboot is needed to complete the upgrade.",
        file=sys.stderr,
    )
    print(
        "Please reboot the node (using the 'reboot' command), "
        "and after the reboot, run the script again",
        file=sys.stderr,
    )
    sys.exit(1)

# Lock root out in four steps:
# - lock its password;
# - delete its SSH directory (including any authorized_keys);
# - delete every sshd_config line mentioning PermitRootLogin, commented or
#   not, then append an explicit "PermitRootLogin no";
# - reload sshd so the new setting takes effect.
print("Disabling root password and SSH access")
sh("passwd -l root")
sh("rm -f -r /root/.ssh/")
sh("sed -i '/.*PermitRootLogin.*/d' /etc/ssh/sshd_config")
sh("echo 'PermitRootLogin no' >> /etc/ssh/sshd_config")
sh("systemctl reload sshd.service")

# Delete the existing host keys and let dpkg-reconfigure regenerate them.
# Its combined stdout/stderr (2>&1) is kept in fp_out; the fingerprints
# shown in the final summary are parsed from it.
print("Resetting SSH host (server) keys and restarting SSH server")
fp_out = sh(
    "rm -f /etc/ssh/ssh_host_* && dpkg-reconfigure openssh-server 2>&1"
)

print(
    "Removing access to /home and other NFS storage "
    "(with the exception of /grid5000)"
)
sh("systemctl stop autofs")
# Lazily unmount every mount whose `mount -l` line starts with "nfs" (its
# source name starts with "nfs"), except /grid5000.
# grep exits 1 when nothing matches, and sh() would raise that as
# CalledProcessError. "|| true" turns "no such mount" into an empty result.
for mount in sh("mount -l | grep -e '^nfs' || true").splitlines():
    # Line format: "<source> on <mountpoint> type <fstype> (<options>)"
    fields = mount.split()
    if fields and fields[2] != "/grid5000":
        sh(f"umount -l {fields[0]}")
# Delete the autofs maps for /home, storage and Guix, plus the Guix profile
# script, so restarting autofs does not bring those mounts back.
files = [
    "/etc/auto.master.d/home.autofs",
    "/etc/auto.home",
    "/etc/auto.master.d/storage.autofs",
    "/etc/auto.storage",
    "/etc/auto.master.d/guix.autofs",
    "/etc/auto.guix",
    "/etc/profile.d/guix.sh",
]
sh("rm -f " + " ".join(files))
sh("systemctl start autofs")

print("Disabling LDAP authentication")
sh("systemctl stop nslcd nscd")
sh("apt-get -y purge libnss-ldapd libpam-ldapd nslcd nscd")

print(
    f"Creating a local user account for {user} with sudo privilege, "
    "adding the user's SSH public key to authorized_keys"
)
sh(f"useradd --shell /bin/bash {user}")
sh(f"echo '{user} ALL=(ALL:ALL) NOPASSWD: ALL' > /etc/sudoers.d/{user}")
sh(f"mkdir -p /home/{user}/.ssh")
# Private modes for the SSH directory and key file; `cp -a` later carries
# these modes over to the encrypted home.
sh(f"chmod 700 /home/{user}/.ssh")
authorized_keys = f"/home/{user}/.ssh/authorized_keys"
with open(authorized_keys, "w", encoding="utf-8") as f:
    # Terminate the last key line so the file is well-formed and anything
    # appended later starts on its own line.
    f.write(pubkey + "\n")
os.chmod(authorized_keys, 0o600)

# Firewall.
# IPv4: flush the nat, mangle and filter tables and delete user-defined
# chains. Then accept only loopback, replies to existing connections, ICMP,
# and SSH from the two Grid'5000 access machines. All other inbound traffic
# is dropped by policy.
# IPv6: DROP policy on INPUT, OUTPUT and FORWARD, with no ACCEPT rule added.
print(
    "Setting up firewall (INPUT restricted to access-{north,south} and ICMP)"
)
sh("iptables -t nat -F && iptables -t mangle -F && iptables -F && iptables -X")
sh("iptables -A INPUT -i lo -j ACCEPT")
sh("iptables -A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT")
sh("iptables -A INPUT -p icmp -j ACCEPT")
# iptables resolves these host names once, when each rule is added.
sh("iptables -A INPUT -s access-north.grid5000.fr -p tcp --dport 22 -j ACCEPT")
sh("iptables -A INPUT -s access-south.grid5000.fr -p tcp --dport 22 -j ACCEPT")
sh("iptables -P INPUT DROP")
sh("ip6tables -P INPUT DROP")
sh("ip6tables -P OUTPUT DROP")
sh("ip6tables -P FORWARD DROP")
# Installing a package after the DROP policy still works over IPv4: replies
# to outgoing connections match the RELATED,ESTABLISHED rule above.
sh("apt-get -y install iptables-persistent")
# Persist the current rule set with netfilter-persistent.
sh("netfilter-persistent save")

# Replace the disk partition mounted on /tmp with a 2G tmpfs. The freed
# partition (tmp_dev) is reused below as the first LVM physical volume of
# the encrypted storage.
print("Unmounting /tmp and remount it under tmpfs")
# Drop the existing /tmp entry from fstab, then reload systemd so it picks
# up the edited file.
sh(r"sed -E -i '/^\S+\s+\/tmp\s.*/d' /etc/fstab")
sh("systemctl daemon-reload")
# NOTE(review): these services are presumably stopped because they keep
# files open under /tmp, which would make the umount below fail; confirm.
sh(
    "systemctl stop haveged.service ModemManager.service ntp.service "
    "systemd-logind.service"
)
# Name of the partition mounted on /tmp. Only direct children of top-level
# block devices are searched; .pop() raises IndexError if none matches.
tmp_dev = [
    part["name"]
    for bd in json.loads(sh("lsblk -J"))["blockdevices"]
    for part in bd.get("children", [])
    if part.get("mountpoint", "") == "/tmp"
].pop()
sh(f"umount /dev/{tmp_dev}")
# nodev,nosuid tmpfs limited to 2G.
sh(
    "echo "
    "'tmpfs   /tmp         tmpfs   rw,nodev,nosuid,size=2G          0  0' "
    ">> /etc/fstab"
)
sh("mount /tmp")
sh(
    "systemctl start haveged.service ModemManager.service ntp.service "
    "systemd-logind.service"
)

print("Setting up encrypted local storage")
# apt-get rather than apt: apt has no stable command-line interface for
# scripts. This also matches the other package commands in this script.
sh("apt-get update && apt-get -y install cryptsetup")
# The key is 2048 random bytes (4 x 512 from /dev/random), base64-encoded.
# It is kept in /run/user/0/key until the end of the script, where it is
# printed once and deleted.
# mkdir -p is a no-op when the directory already exists.
sh("mkdir -p -m 700 /run/user/0")
sh(
    "dd bs=512 count=4 if=/dev/random of=/dev/stdout iflag=fullblock | base64 "
    "> /run/user/0/key"
)
# Remove any LVM global_filter, presumably so LVM accepts every disk used
# below. Then force-remove all device-mapper mappings.
sh("sed -i '/^global_filter =.*/d' /etc/lvm/lvm.conf")
sh("dmsetup remove_all -f")
sh(f"pvcreate -f /dev/{tmp_dev}")
# Check if vg exists already, see bug 15643
# This only fixes the case where a vg named 'vg' already exists
# (created by a previous g5k-armor-node run)
# Fixing the general case of dirty secondary disks should be addressed
# by the bugs 12799 and 11073
existing_vg = sh(
    "vgdisplay vg 2>/dev/null | grep 'VG Name' | awk '{print $3}'"
)
if existing_vg == "vg":
    sh("vgremove -ff vg")
sh(f"vgcreate vg /dev/{tmp_dev}")
# Wipe every whole disk with no direct partition mounted on "/" and add it
# to the volume group.
# NOTE(review): only direct children are checked, so a root filesystem
# nested deeper (e.g. on LVM) would not protect its disk.
for disk in json.loads(sh("lsblk -J")).get("blockdevices", []):
    if disk.get("type", "") == "disk" and not any(
        part.get("mountpoint", "") == "/" for part in disk.get("children", [])
    ):
        sh(f"wipefs -af /dev/{disk['name']}")
        sh(f"pvcreate -f /dev/{disk['name']}")
        sh(f"vgextend vg /dev/{disk['name']}")
# One logical volume spanning the whole group, LUKS-formatted with the key
# file, opened as /dev/mapper/encrypted and formatted as ext4.
sh("lvcreate -y -l100%FREE -n data vg")
sh("cryptsetup --batch-mode luksFormat /dev/mapper/vg-data /run/user/0/key")
sh(
    "cryptsetup luksOpen --key-file /run/user/0/key "
    "/dev/mapper/vg-data encrypted"
)
sh("mkfs.ext4 /dev/mapper/encrypted")

print("Copying user's SSH key to encrypted home")
sh("mount /dev/mapper/encrypted /mnt")
sh(f"cp -a /home/{user}/.ssh /mnt")
sh("umount /mnt")

print("Mounting encrypted storage as user's home")
sh(f"mount /dev/mapper/encrypted /home/{user}")
sh(f"chown -R {user}:{user} /home/{user}")

banner_rule = "------------------------------------------------------------"

# Host key fingerprints. Keep the lines of the dpkg-reconfigure output
# (fp_out) that mention this site's domain, and print fields 4 and 2 of
# each. NOTE(review): this assumes ssh-keygen-style
# "<bits> <hash> <user@host> (<type>)" lines; confirm.
print("\n" + banner_rule)
print("SSH server fingerprints:")
print(banner_rule)
fingerprint_lines = []
for fp_line in fp_out.split("\n"):
    if f".{site}.grid5000.fr" in fp_line:
        fp_fields = fp_line.split()
        fingerprint_lines.append(f"- {fp_fields[3]} {fp_fields[1]}")
print("\n".join(fingerprint_lines))
print(banner_rule)

# Show the LUKS key once, then delete the only on-disk copy.
# NOTE(review): sh() strips surrounding whitespace, so the printed key
# presumably lacks the trailing newline base64 wrote into the key file that
# cryptsetup used. Confirm that restoring the key accounts for it.
print("\n" + banner_rule)
print("\033[1mPlease keep a secure copy of the encrypted storage key:\033[0m")
print(banner_rule)
print(sh("cat /run/user/0/key"))
print(banner_rule)
sh("rm /run/user/0/key")

print("\033[1mSetup completed successfully!\033[0m")
# TODO: Close connection to node