#!/usr/bin/python3

import contextlib
import glob
import json
import os
import subprocess
import sys
import tempfile
import xml.etree.ElementTree


@contextlib.contextmanager
def nbd_connect(image):
    """Attach `image` read-only to a network block device for the duration of the context.

    Every path matching `/dev/nbd*` is tried in turn with `qemu-nbd --connect`;
    the first one for which the command exits with status 0 is yielded. The
    device is disconnected again when the context exits, also on error.

    Raises:
        RuntimeError: if `qemu-nbd` did not succeed for any candidate device.
        subprocess.CalledProcessError: if disconnecting the device fails.
    """
    connected = None
    for candidate in glob.glob("/dev/nbd*"):
        # a non-zero exit status is not an error here, just try the next device
        attach = subprocess.run(["qemu-nbd", "--connect", candidate, "--read-only", image], check=False)
        if attach.returncode == 0:
            connected = candidate
            break

    if connected is None:
        raise RuntimeError("no free network block device")

    try:
        yield connected
    finally:
        subprocess.run(["qemu-nbd", "--disconnect", connected], check=True, stdout=subprocess.DEVNULL)


@contextlib.contextmanager
def mount_at(device, mountpoint, options=None):
    """Mount `device` read-only on the existing directory `mountpoint`.

    Args:
        device: path of the block device to mount.
        mountpoint: directory to mount the device on.
        options: optional iterable of additional mount options; they are
            appended to the always-present `ro` option.

    Yields:
        `mountpoint`, while the device is mounted. The filesystem is lazily
        unmounted when the context exits, also on error.

    Raises:
        subprocess.CalledProcessError: if `mount` or `umount` fails.
    """
    # `None` rather than `[]` as the default: a list default would be one
    # object shared between all calls.
    opts = ",".join(["ro", *(options or ())])
    subprocess.run(["mount", "-o", opts, device, mountpoint], check=True)
    try:
        yield mountpoint
    finally:
        subprocess.run(["umount", "--lazy", mountpoint], check=True)


@contextlib.contextmanager
def mount(device):
    """Mount `device` read-only on a fresh temporary directory.

    Yields:
        The path of the temporary mountpoint. On exit the filesystem is
        unmounted and the temporary directory is removed.

    Raises:
        subprocess.CalledProcessError: if `mount` or `umount` fails.
    """
    with tempfile.TemporaryDirectory() as mountpoint:
        # reuse mount_at() instead of duplicating the mount/umount handling;
        # without extra options it runs exactly `mount -o ro`
        with mount_at(device, mountpoint) as tree:
            yield tree


def parse_environment_vars(s):
    """Parse text made of `KEY=VALUE` lines into a dict.

    Blank lines and lines starting with `#` are ignored. The line is split at
    the first `=`; double quotes are stripped from both ends of the value.
    Lines that contain no `=` at all are skipped instead of aborting the
    whole parse.

    Args:
        s: the text to parse.

    Returns:
        A dict mapping each key to its (unquoted) string value.
    """
    r = {}
    for line in s.split("\n"):
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            # not an assignment, nothing to record
            continue
        r[key] = value.strip('"')
    return r


def parse_unit_files(s, expected_state):
    """Extract the units in a given state from `systemctl list-unit-files` output.

    The first line of `s` is skipped as the table header. Every following line
    is split on whitespace; the first field is taken as the unit name and the
    second as its state, any further fields are ignored. Lines with fewer than
    two fields (such as blank lines) are skipped.

    Args:
        s: the command output to parse.
        expected_state: only units whose state equals this string are returned.

    Returns:
        A sorted list of unit names, without duplicates.
    """
    units = set()
    for line in s.split("\n")[1:]:
        fields = line.split()
        if len(fields) < 2:
            # Skip the line entirely: falling through would test a `state`
            # that is either undefined or left over from the previous line.
            continue
        unit, state = fields[0], fields[1]
        if state == expected_state:
            units.add(unit)

    return sorted(units)


def subprocess_check_output(argv, parse_fn=None):
    """Run `argv` and return its standard output, optionally parsed.

    Args:
        argv: the command line to execute, as a list.
        parse_fn: optional callable applied to the decoded (UTF-8) output.

    Returns:
        `parse_fn(output)` if `parse_fn` is given, the output string otherwise.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status. Its captured output is written to stderr before the
            exception is re-raised.
    """
    try:
        raw = subprocess.check_output(argv, encoding="utf-8")
    except subprocess.CalledProcessError as err:
        # show what the failed command printed, to ease debugging
        for chunk in (f"--- Output from {argv}:\n", err.stdout, "\n--- End of the output\n"):
            sys.stderr.write(chunk)
        raise

    if parse_fn:
        return parse_fn(raw)
    return raw


def read_image_format(device):
    """Return the `format` field that `qemu-img info` reports for `device`."""
    info_cmd = ["qemu-img", "info", "--output=json", device]
    info = subprocess_check_output(info_cmd, json.loads)
    return info["format"]


def read_partition(device, bootable, typ=None, start=0, size=0, type=None):
    """Describe one partition (or whole device) using `blkid`.

    Args:
        device: path of the block device to probe.
        bootable: stored as-is under `bootable`.
        typ: stored as-is under `type`.
        start: stored as-is under `start`.
        size: stored as-is under `size`.
        type: not used by this function.

    Returns:
        A dict with the keys `label`, `type`, `uuid`, `partuuid`, `fstype`,
        `bootable`, `start` and `size`. The values taken from blkid are None
        when blkid did not report the corresponding tag.
    """
    probe = subprocess_check_output(["blkid", "--output", "export", device], parse_environment_vars)
    return dict(
        label=probe.get("LABEL"),  # doesn't exist for mbr
        type=typ,
        uuid=probe.get("UUID"),
        partuuid=probe.get("PARTUUID"),
        fstype=probe.get("TYPE"),
        bootable=bootable,
        start=start,
        size=size,
    )


def read_partition_table(device):
    """Read the partition table of `device` with `sfdisk --json`.

    Returns:
        A tuple `(label, id, partitions)`: the partition table label and id
        as reported by sfdisk, and a list with one read_partition() dict per
        partition, with `start` and `size` converted from sectors to bytes.
        If sfdisk exits with an error, the device is described as a single
        non-bootable partition and `label` and `id` are None.
    """
    try:
        sfdisk = subprocess_check_output(["sfdisk", "--json", device], json.loads)
    except subprocess.CalledProcessError:
        # presumably there is no partition table: describe the bare device
        return None, None, [read_partition(device, False)]

    ptable = sfdisk["partitiontable"]
    assert ptable["unit"] == "sectors"

    # Use the sector size sfdisk reports instead of assuming 512 bytes; fall
    # back to 512 for sfdisk versions whose JSON lacks the field.
    sector_size = ptable.get("sectorsize", 512)

    partitions = [
        read_partition(
            p["node"],
            p.get("bootable", False),
            p["type"],
            p["start"] * sector_size,
            p["size"] * sector_size,
        )
        # NOTE(review): `.get` guards against a table without any partition
        # entries, in case sfdisk omits the key then -- confirm
        for p in ptable.get("partitions", [])
    ]
    return ptable["label"], ptable["id"], partitions


def read_bootloader_type(device):
    """Guess the bootloader from the first 512 bytes of `device`.

    Returns:
        "grub" if those bytes contain the byte string `GRUB`, "unknown"
        otherwise.
    """
    with open(device, "rb") as f:
        first_sector = f.read(512)

    return "grub" if b"GRUB" in first_sector else "unknown"


def read_boot_entries(boot_dir):
    """Read the boot loader entries found in `{boot_dir}/loader/entries/*.conf`.

    Each non-empty line that is not a `#` comment is split at the first run of
    whitespace into a key and a value; a line consisting of a key only gets
    the empty string as its value.

    Args:
        boot_dir: directory that contains the `loader/entries` directory.

    Returns:
        A list with one dict per entry file, sorted by the entry's `title`
        (entries without a title sort first).
    """
    entries = []
    for conf in glob.glob(f"{boot_dir}/loader/entries/*.conf"):
        entry = {}
        with open(conf) as f:
            for line in f:
                line = line.strip()
                # blank lines and comments carry no key/value pair
                if not line or line.startswith("#"):
                    continue
                key, *rest = line.split(None, 1)
                entry[key] = rest[0] if rest else ""
        entries.append(entry)

    return sorted(entries, key=lambda e: e.get("title", ""))


def rpm_verify(tree):
    """Run `rpm --verify --all` inside `tree` and collect the reported files.

    Args:
        tree: root directory of the filesystem to chroot into.

    Returns:
        A dict with two keys: `missing`, a sorted list of paths reported as
        missing, and `changed`, a dict mapping each other reported path to
        the 9-character attribute string rpm printed for it.
    """
    # cannot use `rpm --root` here, because rpm uses passwd from the host to
    # verify user and group ownership:
    #   https://github.com/rpm-software-management/rpm/issues/882
    argv = ["chroot", tree, "rpm", "--verify", "--all"]

    changed = {}
    missing = []
    # As a context manager, Popen closes the pipe and waits for the process
    # on exit, even when parsing raises. The return value is ignored, because
    # rpm returns non-zero when it found changes.
    with subprocess.Popen(argv, stdout=subprocess.PIPE, encoding="utf-8") as rpm:
        for line in rpm.stdout:
            # format description in rpm(8), under `--verify`
            attrs = line[:9]
            if attrs == "missing  ":
                missing.append(line[12:].rstrip())
            else:
                changed[line[13:].rstrip()] = attrs

    return {
        "missing": sorted(missing),
        "changed": changed
    }


def read_services(tree, state):
    """Return the unit files in `tree` that systemctl lists with the given `state`.

    Runs `systemctl --root=<tree> list-unit-files` and filters its output
    with parse_unit_files().
    """
    def units_in_state(listing):
        return parse_unit_files(listing, state)

    list_cmd = ["systemctl", f"--root={tree}", "list-unit-files"]
    return subprocess_check_output(list_cmd, units_in_state)


def read_firewall_zone(tree):
    """Return the names of the services enabled in the default firewalld zone.

    The default zone is read from `DefaultZone` in
    `{tree}/etc/firewalld/firewalld.conf`; when the file or the key is
    missing, `public` is used. The zone definition is looked up in
    `{tree}/etc/firewalld/zones` first and in `{tree}/usr/lib/firewalld/zones`
    as a fallback.

    Args:
        tree: root directory of the filesystem to inspect.

    Returns:
        A list with the `name` attribute of every `service` element of the
        zone, in document order.

    Raises:
        FileNotFoundError: if the zone file exists in neither location.
    """
    try:
        with open(f"{tree}/etc/firewalld/firewalld.conf") as f:
            conf = parse_environment_vars(f.read())
    except FileNotFoundError:
        conf = {}

    # a configuration file without a DefaultZone entry is treated like a
    # missing file instead of failing with a KeyError
    default = conf.get("DefaultZone", "public")

    try:
        root = xml.etree.ElementTree.parse(f"{tree}/etc/firewalld/zones/{default}.xml").getroot()
    except FileNotFoundError:
        root = xml.etree.ElementTree.parse(f"{tree}/usr/lib/firewalld/zones/{default}.xml").getroot()

    return [element.get("name") for element in root.findall("service")]


def append_filesystem(report, tree):
    """Add what can be learned from the filesystem mounted at `tree` to `report`.

    Three kinds of filesystems are recognized:

    * one containing `/etc/os-release`: packages, rpm verification, os-release,
      enabled/disabled services, passwd and groups are recorded, plus
      hostname, timezone, firewall services, fstab and boot configuration
      when the corresponding files exist;
    * one containing `vmlinuz-*` files at its top level: only the boot
      environment and boot menu are recorded;
    * one containing an `EFI` entry: a note is printed to stderr.

    Args:
        report: dict that is updated in place.
        tree: directory where the filesystem is mounted.
    """
    if os.path.exists(f"{tree}/etc/os-release"):
        report["packages"] = sorted(subprocess_check_output(["rpm", "--root", tree, "-qa"], str.split))
        report["rpm-verify"] = rpm_verify(tree)

        with open(f"{tree}/etc/os-release") as f:
            report["os-release"] = parse_environment_vars(f.read())

        report["services-enabled"] = read_services(tree, "enabled")
        report["services-disabled"] = read_services(tree, "disabled")

        try:
            with open(f"{tree}/etc/hostname") as f:
                report["hostname"] = f.read().strip()
        except FileNotFoundError:
            pass

        # Only a symlink tells us the timezone name. os.readlink() on a
        # regular file raises a plain OSError, not FileNotFoundError, so
        # check for the link explicitly (also false if the path is missing).
        localtime = f"{tree}/etc/localtime"
        if os.path.islink(localtime):
            report["timezone"] = os.path.basename(os.readlink(localtime))

        try:
            report["firewall-enabled"] = read_firewall_zone(tree)
        except FileNotFoundError:
            pass

        try:
            with open(f"{tree}/etc/fstab") as f:
                # strip first, so that whitespace-only lines and indented
                # comments do not end up as entries
                lines = (line.strip() for line in f.read().split("\n"))
                report["fstab"] = sorted(line.split() for line in lines if line and not line.startswith("#"))
        except FileNotFoundError:
            pass

        with open(f"{tree}/etc/passwd") as f:
            report["passwd"] = sorted(f.read().strip().split("\n"))

        with open(f"{tree}/etc/group") as f:
            report["groups"] = sorted(f.read().strip().split("\n"))

        if os.path.exists(f"{tree}/boot") and len(os.listdir(f"{tree}/boot")) > 0:
            # the boot configuration must come from exactly one filesystem
            assert "bootmenu" not in report
            try:
                with open(f"{tree}/boot/grub2/grubenv") as f:
                    report["boot-environment"] = parse_environment_vars(f.read())
            except FileNotFoundError:
                pass
            report["bootmenu"] = read_boot_entries(f"{tree}/boot")

    elif len(glob.glob(f"{tree}/vmlinuz-*")) > 0:
        # a separate boot partition: same layout as /boot above, minus the prefix
        assert "bootmenu" not in report
        # tolerate a missing grubenv, like the /boot case above does
        try:
            with open(f"{tree}/grub2/grubenv") as f:
                report["boot-environment"] = parse_environment_vars(f.read())
        except FileNotFoundError:
            pass
        report["bootmenu"] = read_boot_entries(tree)
    elif len(glob.glob(f"{tree}/EFI")):
        print("EFI partition", file=sys.stderr)


def partition_is_esp(partition):
    """Tell whether the `type` of `partition` is the EFI System Partition GUID.

    The comparison is an exact, case-sensitive string match.
    """
    esp_type_guid = "C12A7328-F81F-11D2-BA4B-00A0C93EC93B"
    return esp_type_guid == partition["type"]


def find_esp(partitions):
    """Find the first EFI System Partition in `partitions`.

    Returns:
        A tuple `(partition, index)` for the first partition accepted by
        partition_is_esp(), or `(None, 0)` if there is none.
    """
    matches = ((candidate, index) for index, candidate in enumerate(partitions) if partition_is_esp(candidate))
    return next(matches, (None, 0))


def main(image):
    """Build the report for the disk image at path `image`.

    The image is attached to a network block device, its format, bootloader
    and partition table are recorded, and every partition on which a
    filesystem type was detected is mounted and inspected with
    append_filesystem(). Without a partition table, the device itself is
    mounted and inspected.

    Returns:
        The report as a dict.
    """
    report = {}
    with nbd_connect(image) as device:
        report["image-format"] = read_image_format(image)
        report["bootloader"] = read_bootloader_type(device)

        label, table_id, partitions = read_partition_table(device)
        report["partition-table"] = label
        report["partition-table-id"] = table_id
        report["partitions"] = partitions

        if label:
            esp, esp_index = find_esp(partitions)
            for index, partition in enumerate(partitions):
                if not partition["fstype"]:
                    continue

                # partition devices are named <device>p<N>, counting from 1
                with mount(f"{device}p{index + 1}") as tree:
                    if esp and os.path.exists(f"{tree}/boot/efi"):
                        # have the ESP in place while this filesystem is inspected
                        with mount_at(f"{device}p{esp_index + 1}", f"{tree}/boot/efi", options=['umask=077']):
                            append_filesystem(report, tree)
                    else:
                        append_filesystem(report, tree)
        else:
            with mount(device) as tree:
                append_filesystem(report, tree)

    return report

if __name__ == "__main__":
    # load the nbd kernel module before main() looks for /dev/nbd* devices
    subprocess.run(["modprobe", "nbd"], check=True)

    # the image to inspect is the first command line argument; the report is
    # written to stdout as JSON
    json.dump(main(sys.argv[1]), sys.stdout, sort_keys=True, indent=2)
