#!/usr/bin/env python3
import argparse
import os
import shlex
import subprocess
import sys
from pathlib import Path

from fmt import green, red, yellow

# Two directory levels above this script: the parent of the directory that
# contains it (presumably the repository root — confirm against the layout).
ROOT = Path(__file__).parent.parent
# Both directories are appended to PATH for the test pipelines in main().
TOOLS = ROOT / "tools"
# Presumably the Cargo release build output directory — confirm.
TARGET_RELEASE = ROOT / "target" / "release"


def relative(path):
    """Return *path* as a display string relative to the current directory.

    ``Path.relative_to`` raises ``ValueError`` when *path* is not inside the
    cwd (e.g. tests passed as ``../tests``). This helper is only used for
    printing, so instead of aborting the run it falls back to an
    ``os.path.relpath`` form (``../tests/...``). If even that is impossible
    (``relpath`` raises for paths on different drives on Windows), the path
    is returned as given.
    """
    path = Path(path)
    try:
        return str(path.relative_to(Path.cwd()))
    except ValueError:
        pass
    try:
        return os.path.relpath(path)
    except ValueError:
        return str(path)


def discover(roots):
    """Yield ``(path, directives)`` for every test file found under *roots*.

    Each entry of *roots* may name a file or a directory. Directories are
    walked recursively in sorted order, so runs are reproducible. A file is a
    test when at least one of its lines contains ``RUN:``. For each such line,
    the text after the *first* ``RUN:`` (trailing newline included) becomes
    one entry of ``directives["RUN"]``. Yielded paths are resolved
    ``pathlib.Path`` objects.

    Raises:
        FileNotFoundError: (when iterated) if an entry of *roots* does not
            exist. Silently yielding nothing would make a mistyped directory
            look like a passing run.
    """

    def inspect(path):
        # A directory walk reaches arbitrary files (binaries, non-UTF-8 text).
        # Decode leniently so one such file cannot abort discovery.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            # maxsplit=1: keep any later "RUN:" that belongs to the command.
            run_lines = [line.split("RUN:", 1)[1] for line in f if "RUN:" in line]
        if run_lines:
            yield (path, {"RUN": run_lines})

    for root in roots:
        root = Path(root).resolve()
        if root.is_file():
            yield from inspect(root)
            continue
        if not root.exists():
            raise FileNotFoundError(f"no such file or directory: {root}")

        for dirpath, dirnames, filenames in os.walk(root):
            dirnames.sort()  # in place, so os.walk descends in sorted order
            for fname in sorted(filenames):
                yield from inspect(Path(dirpath) / fname)


def split_seq(seq, split):
    """Lazily split *seq* into lists at each item equal to *split*.

    Separator items are dropped. Like ``str.split``, this always yields at
    least one list: an empty input gives ``[[]]``, and leading or trailing
    separators produce empty lists at the edges.
    """
    pending = []
    for element in seq:
        if element == split:
            finished, pending = pending, []
            yield finished
        else:
            pending.append(element)
    yield pending


def _parse_pipeline(pipeline, source_path):
    """Split a RUN line into a list of commands, each a list of argv words.

    Words are split with POSIX shell-like quoting rules, and commands are
    separated by ``|`` tokens. Every ``%s`` inside a word is replaced by
    *source_path*. No shell is involved, so redirections, globs and variables
    are passed to the programs as literal arguments.

    Raises:
        ValueError: on unbalanced quotes (from ``shlex``) or an empty stage.
    """
    lexer = shlex.shlex(pipeline, posix=True, punctuation_chars=True)
    lexer.whitespace_split = True
    commands = [
        [word.replace("%s", source_path) for word in command]
        for command in split_seq(lexer, "|")
    ]
    if not all(commands):
        raise ValueError("empty command in pipeline")
    return commands


def _run_pipeline(commands, env, stdout, stderr):
    """Run *commands* connected stdout->stdin and wait for all of them.

    The first command reads an empty stdin (``subprocess.DEVNULL``). The last
    one writes to the file *stdout*, and every command writes to the file
    *stderr*. A command that cannot be started counts as a failure: the reason
    is written to *stderr` and the commands already started are still waited
    for. Returns True only if every command started and exited with status 0.
    """
    processes = []
    upstream = None  # parent's copy of the pipe read end feeding the next command
    started_all = True
    try:
        for index, command in enumerate(commands):
            is_last = index == len(commands) - 1
            try:
                process = subprocess.Popen(
                    command,
                    stdin=subprocess.DEVNULL if upstream is None else upstream,
                    stdout=stdout if is_last else subprocess.PIPE,
                    stderr=stderr,
                    env=env,
                )
            except (OSError, ValueError) as exc:
                stderr.write(f"cannot start {command[0]!r}: {exc}\n".encode())
                started_all = False
                break
            processes.append(process)
            if upstream is not None:
                # The child now holds its own copy of the read end. Dropping
                # ours means the producer gets EPIPE/SIGPIPE if the consumer
                # exits early, instead of blocking on a full pipe.
                upstream.close()
            upstream = process.stdout  # None for the last command
    finally:
        if upstream is not None:
            upstream.close()
        for process in processes:
            process.wait()
    return started_all and all(p.returncode == 0 for p in processes)


def _dump_stream(stream):
    """Copy the entire binary *stream* (rewound first) to our stdout."""
    stream.seek(0)
    # Push pending text-layer output first so the raw bytes land in order.
    sys.stdout.flush()
    while chunk := stream.read(65536):
        sys.stdout.buffer.write(chunk)


def run_test(source, pipeline, verbose=False, PATH=None):
    """Run one ``RUN:`` pipeline for *source* and print a PASS/FAIL line.

    Args:
        source: path of the test file (``pathlib.Path`` or ``str``). Every
            ``%s`` in the pipeline is replaced by its resolved absolute path.
        pipeline: the text following ``RUN:``. It is split into ``|``-separated
            commands with shell-like quoting, but no shell is run.
        verbose: on failure, also print the last command's stdout and the
            combined stderr of all commands.
        PATH: value of ``PATH`` for the child processes.

    Returns:
        True if every command in the pipeline exited with status 0, else
        False. A RUN line that cannot be parsed, or a command that cannot be
        started, is reported as FAIL instead of aborting the whole run.
    """
    import tempfile  # local import: not used elsewhere in this module

    # NOTE: children receive *only* this environment (nothing is inherited
    # from os.environ); with a falsy PATH it is completely empty.
    env = {"PATH": PATH} if PATH else {}

    print(f"{relative(source):<72}", end="", flush=True)
    # Capture into temporary files, not pipes. The output is only read after
    # every process has exited, so a pipe would fill up (its buffer is finite,
    # typically 64 KiB on Linux) and deadlock the pipeline on large output.
    # buffering=0 keeps parent writes unbuffered alongside the children's.
    with tempfile.TemporaryFile(buffering=0) as out, tempfile.TemporaryFile(
        buffering=0
    ) as err:
        try:
            commands = _parse_pipeline(pipeline, str(Path(source).resolve()))
        except ValueError as exc:
            err.write(f"cannot parse RUN line: {exc}\n".encode())
            passed = False
        else:
            passed = _run_pipeline(commands, env, out, err)

        if passed:
            print(green("PASS"))
            return True

        print(red("FAIL"))
        if verbose:
            print(yellow("stdout:"))
            _dump_stream(out)
            print()
            print(yellow("stderr:"))
            _dump_stream(err)
            print()
        return False


def report(tests):
    """Print a PASS/FAIL summary of *tests* and return the exit status.

    *tests* maps each test path to ``"PASS"`` or ``"FAIL"``; any other status
    is left out of both counts. Failing paths are listed relative to the
    current directory. Returns 0 when nothing failed, 1 otherwise.
    """
    passed, failed = [], []
    for test_path, outcome in tests.items():
        if outcome == "PASS":
            passed.append(test_path)
        elif outcome == "FAIL":
            failed.append(test_path)

    print(f"\nPASS: {len(passed)}\nFAIL: {len(failed)}\n")

    if not failed:
        print(green("PASSED\n"))
        return 0

    print(yellow("Failures:"))
    print()
    for test_path in failed:
        print(f"    {relative(test_path)}")
    print(red("\nFAILED\n"))
    return 1


def main():
    """Command-line entry point: find tests, run their RUN lines, report.

    Exits with the status returned by ``report``: 0 if every test file
    passed, 1 otherwise. A test file passes only if *all* of its RUN lines
    pass. Nonexistent test arguments are rejected with a usage error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "tests", nargs="+", help="test files or directories to scan for RUN: lines"
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="print stdout/stderr of failing pipelines",
    )
    args = parser.parse_args()

    missing = [test for test in args.tests if not os.path.exists(test)]
    if missing:
        parser.error(f"no such file or directory: {', '.join(missing)}")

    # PATH may be unset in the parent; fall back to the platform default
    # rather than crashing on join(None, ...).
    # NOTE(review): the tool directories are *appended*, so a same-named
    # program earlier on the user's PATH shadows them — confirm intended.
    search_path = os.pathsep.join(
        [os.environ.get("PATH", os.defpath), str(TOOLS), str(TARGET_RELEASE)]
    )

    tests = {}
    for path, directives in discover(args.tests):
        # Run every RUN line so each one is reported, but the file counts as
        # passing only if all of them succeeded.
        results = [
            run_test(path, run_line, args.verbose, search_path)
            for run_line in directives["RUN"]
        ]
        tests[path] = "PASS" if all(results) else "FAIL"
    sys.exit(report(tests))


# Run the test runner only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
