#!/usr/bin/env python3

"""
Wrapper script for running the Rumur model checker.

This script is intended to be installed alongside the `rumur` binary from the
Rumur model checker. It can then be used to quickly generate and run a model, as
an alternative to having to run the model generation, compilation and execution
steps manually.
"""

import os
import platform
import re
import shutil
import subprocess as sp
import sys
import tempfile
from pathlib import Path

# C compiler: the executable named by the CC environment variable if it is
# set, otherwise "cc", resolved to a full path with shutil.which. This is None
# when no matching executable can be found.
CC = shutil.which(os.environ.get("CC", "cc"))


def categorise(cc):
    """
    Determine the vendor of a given C compiler

    A small probe program is compiled with `cc` in a temporary directory and
    then executed. The probe prints a vendor name chosen from the predefined
    preprocessor macros it sees.

    Args:
        cc: name or path of the C compiler to probe

    Returns:
        Whatever the probe printed, i.e. "clang", "gcc" or "unknown". The
        fallback "unknown" is also returned when the compiler cannot be
        executed, when compilation fails, or when the compiled probe cannot be
        run or exits with a non-zero status.
    """

    # Probe program. __clang__ is tested before __GNUC__ (presumably because
    # Clang defines both).
    probe = """\
        #include <stdio.h>
        #include <stdlib.h>
        int main(void) {
        #ifdef __clang__
          printf("clang\\n");
        #elif defined(__GNUC__)
          printf("gcc\\n");
        #else
          printf("unknown\\n");
        #endif
          return EXIT_SUCCESS;
        }
        """

    # Create a temporary area to compile a test file
    with tempfile.TemporaryDirectory() as t:
        tmp = Path(t)

        # Setup the test file
        src = tmp / "test.c"
        with open(str(src), "wt", encoding="utf-8") as f:
            f.write(probe)

        # Compile it, discarding any diagnostics
        aout = tmp / "a.out"
        try:
            cc_ret = sp.call(
                [cc, "-o", str(aout), str(src)],
                stdout=sp.DEVNULL,
                stderr=sp.DEVNULL,
            )
        except OSError:
            # the compiler itself could not be executed
            return "unknown"

        if cc_ret != 0:
            return "unknown"

        # Run it. OSError covers a binary that is missing or cannot be
        # executed on this host; categorisation is best-effort, so fall back
        # rather than crash.
        try:
            return sp.check_output([str(aout)], universal_newlines=True).strip()
        except (sp.CalledProcessError, OSError):
            return "unknown"


def supports(flag):
    """
    Probe whether the C compiler accepts a given command line flag

    A trivial program is fed to the compiler on stdin with `flag` appended to
    the command line, and the compiled output is discarded. Returns True when
    the compiler exits with status 0 and False otherwise.
    """

    # minimal translation unit; its content is irrelevant to the probe
    source = "int main(void) { return 0; }".encode("utf-8", "replace")

    # "-x c -" makes the compiler read C from stdin; diagnostics are silenced
    command = [CC, "-o", os.devnull, "-x", "c", "-", flag]
    outcome = sp.run(command, input=source, stderr=sp.DEVNULL)

    return outcome.returncode == 0


def needs_libatomic():
    """
    Probe whether a double-word compare-and-swap requires linking -latomic

    A program exercising double-word atomic read, write and CAS is compiled
    and linked without -latomic. Returns True when that build fails, which is
    taken to mean -latomic is needed, and False when it succeeds.
    """

    # CAS program to ask it to compile
    program = """
#include <stdbool.h>
#include <stdint.h>

// replicate what is in ../resources/header.c

#define THREADS 2

#if __SIZEOF_POINTER__ <= 4
  typedef uint64_t dword_t;
#elif __SIZEOF_POINTER__ <= 8
  typedef unsigned __int128 dword_t;
#else
  #error "unexpected pointer size; what scalar type to use for dword_t?"
#endif

static dword_t atomic_read(dword_t *p) {

  if (THREADS == 1) {
    return *p;
  }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__))
  /* x86-64: MOV is not guaranteed to be atomic on 128-bit naturally aligned
   *   memory. The way to work around this is apparently the following
   *   degenerate CMPXCHG16B.
   * i386: __atomic_load_n emits code calling a libatomic function that takes a
   *   lock, making this no longer lock free. Force a CMPXCHG8B by using the
   *   __sync built-in instead.
   * ARM64: __atomic_load_n emits code calling a libatomic function that takes a
   *   lock, making this no longer lock free. Force a CASP by using the __sync
   *   built-in instead.
   *
   * XXX: the obvious (irrelevant) literal to use here is “0” but this triggers
   * a GCC bug on ARM, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=114310.
   */
  return __sync_val_compare_and_swap(p, 1, 1);
#endif

  return __atomic_load_n(p, __ATOMIC_ACQUIRE);
}

static void atomic_write(dword_t *p, dword_t v) {

  if (THREADS == 1) {
    *p = v;
    return;
  }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__))
  /* As explained above, we need some extra gymnastics to avoid a call to
   * libatomic on x86-64, i386, and ARM64.
   */
  dword_t expected;
  dword_t old = 0;
  do {
    expected = old;
    old = __sync_val_compare_and_swap(p, expected, v);
  } while (expected != old);
  return;
#endif

  __atomic_store_n(p, v, __ATOMIC_RELEASE);
}

static bool atomic_cas(dword_t *p, dword_t expected, dword_t new) {

  if (THREADS == 1) {
    if (*p == expected) {
      *p = new;
      return true;
    }
    return false;
  }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__))
  /* Make GCC >= 7.1 emit cmpxchg on x86-64 and i386. See
   * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80878.
   * Make GCC emit a CASP on ARM64 (undocumented?).
   */
  return __sync_bool_compare_and_swap(p, expected, new);
#endif

  return __atomic_compare_exchange_n(p, &expected, new, false, __ATOMIC_ACQ_REL,
                                     __ATOMIC_RELAXED);
}

static dword_t atomic_cas_val(dword_t *p, dword_t expected, dword_t new) {

  if (THREADS == 1) {
    dword_t old = *p;
    if (old == expected) {
      *p = new;
    }
    return old;
  }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__))
  /* Make GCC >= 7.1 emit cmpxchg on x86-64 and i386. See
   * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80878.
   * Make GCC emit a CASP on ARM64 (undocumented?).
   */
  return __sync_val_compare_and_swap(p, expected, new);
#endif


  (void)__atomic_compare_exchange_n(p, &expected, new, false, __ATOMIC_ACQ_REL,
                                    __ATOMIC_ACQUIRE);
  return expected;
}

int main(void) {
  dword_t target = 0;

  target = atomic_read(&target);

  atomic_write(&target, 42);

  atomic_cas(&target, 42, 0);

  return (int)atomic_cas_val(&target, 0, 42);
}
"""

    # read the program from stdin as C11 and throw away the linked result;
    # the architecture options are added only when the compiler accepts them
    command = [CC, "-x", "c", "-std=c11", "-", "-o", os.devnull]
    command.extend(opt for opt in ("-march=native", "-mcx16") if supports(opt))

    build = sp.run(
        command, input=program.encode("utf-8", "replace"), stderr=sp.DEVNULL
    )

    # a failed build without -latomic means the library is required
    return build.returncode != 0


def optimisation_flags():
    """
    Build the list of C compiler optimisation options for this platform

    Returns a list that always starts with "-O3", followed by each further
    option that applies to the compiler in `CC`.
    """

    flags = ["-O3"]

    # host architecture code generation, host CPU tuning and link-time
    # optimisation, each included only if the compiler accepts it
    for candidate in ("-march=native", "-mtune=native", "-flto"):
        if supports(candidate):
            flags.append(candidate)

    # allow GCC to perform more advanced interprocedural optimisations
    if categorise(CC) == "gcc":
        flags.append("-fwhole-program")

    # on platforms that need it made explicit, enable CMPXCHG16B
    if supports("-mcx16"):
        flags += ["-mcx16"]

    return flags


def has_no_la57():
    """
    does our hardware lack support for Intel 5-level paging?

    Returns True only when this is an x86-64 machine whose /proc/cpuinfo has a
    flags field that does not list la57. In every other case, including when
    cpuinfo cannot be read or has no flags field, returns False.
    """

    # 5-level paging is only a consideration on x86-64
    if platform.machine() not in ("amd64", "x86_64"):
        return False

    flags_field = re.compile(r"flags\s*:\s*(?P<flags>.*)$")

    # locate the first flags line in cpuinfo, reading no further than needed
    try:
        with open("/proc/cpuinfo", "rt", encoding="utf-8") as cpuinfo:
            hits = (flags_field.match(entry) for entry in cpuinfo)
            first_hit = next((hit for hit in hits if hit is not None), None)
    except (FileNotFoundError, PermissionError):
        # procfs is unavailable
        return False

    # with no flags field at all, conservatively assume LA57 may be supported
    if first_hit is None:
        return False

    # found the flags field; now is it missing LA57?
    return re.search(r"\bla57\b", first_hit.group("flags")) is None


def main(args):
    """
    Generate, compile and run a checker for a model

    Args:
        args: the full command line, with the program name in `args[0]`;
            everything after it is forwarded to `rumur`

    Returns:
        0 when the checker was generated, compiled and ran successfully. If
        `rumur` itself fails, its exit status is returned. -1 is returned when
        `rumur` or a C compiler cannot be found, or when compiling or running
        the checker fails.

    If any argument asks for help or version information, this process is
    replaced by `rumur` and the function does not return.
    """

    # Find the Rumur binary, preferring one that sits next to this script
    rumur_bin = shutil.which(
        os.path.join(os.path.abspath(os.path.dirname(__file__)), "rumur")
    )
    if rumur_bin is None:
        rumur_bin = shutil.which("rumur")
    if rumur_bin is None:
        sys.stderr.write("rumur binary not found\n")
        return -1

    # if the user asked for help or version information, run Rumur directly
    for arg in args[1:]:
        if arg.startswith(("-h", "--h", "--vers")):
            os.execv(rumur_bin, [rumur_bin] + args[1:])

    if CC is None:
        sys.stderr.write("no C compiler found\n")
        return -1

    argv = [rumur_bin]
    # if this hardware does not support 5-level paging, we can more aggressively
    # compress pointers
    if has_no_la57():
        argv += ["--pointer-bits", "48"]
    argv += args[1:] + ["--output", "/dev/stdout"]

    # Generate the checker, capturing the generated C source from stdout.
    #
    # Progress messages are flushed explicitly: the compiler and the checker
    # write straight to the stdout they inherit from us, so when our stdout is
    # a pipe or file (block buffered) unflushed messages would otherwise only
    # show up after the children's output.
    print("Generating the checker...", flush=True)
    rumur_proc = sp.Popen(argv, stdin=sp.PIPE, stdout=sp.PIPE)
    checker_c, _ = rumur_proc.communicate()
    if rumur_proc.returncode != 0:
        return rumur_proc.returncode

    # Setup a temporary directory in which to generate the checker
    with tempfile.TemporaryDirectory() as t:
        tmp = Path(t)

        # Compile the checker, feeding the generated source in on stdin
        print("Compiling the checker...", flush=True)
        aout = tmp / "a.out"
        argv = (
            [CC, "-std=c11"]
            + optimisation_flags()
            + ["-o", str(aout), "-x", "c", "-", "-lpthread"]
        )
        if needs_libatomic():
            argv.append("-latomic")
        cc_proc = sp.Popen(argv, stdin=sp.PIPE)
        cc_proc.communicate(checker_c)
        if cc_proc.returncode != 0:
            return -1

        # Run the checker while the temporary directory still exists
        print("Running the checker...", flush=True)
        if sp.call([str(aout)]) != 0:
            return -1

    return 0


if __name__ == "__main__":
    # an interrupt from the keyboard becomes exit status 130 rather than a
    # traceback; otherwise exit with whatever main reports
    try:
        status = main(sys.argv)
    except KeyboardInterrupt:
        status = 130
    sys.exit(status)
