/* This file is part of GTI (Generic Tool Infrastructure)
 *
 * Copyright (C)
 *  2008-2019 ZIH, Technische Universitaet Dresden, Federal Republic of Germany
 *  2008-2019 Lawrence Livermore National Laboratories, United States of America
 *  2013-2019 RWTH Aachen University, Federal Republic of Germany
 *
 * See the LICENSE file in the package base directory for details
 */

/**
 * @file CStratCrashHandling.cpp
 *  Extension to communication strategies to handle crashes of the application.
 *
 *  This extension provides handlers for MPI errors and signals. By using
 * these handlers, the tool processes can finish all analyses before the
 * application gets stopped.
 *
 *
 * @author Joachim Protze
 * @date 20.03.2012
 *
 */

#include "CStratCrashHandling.h"

#include <array>
#include <signal.h>
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <pnmpimod.h>
#include <unistd.h>
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <gtiConfig.h>
#include <sys/types.h>
#include <unistd.h>
#include <execinfo.h>
#include <vector>
#include <mutex>
#include <functional>
#include <iostream>
#include <map>
#include "safe_ptr.h"

#include "ErrhandlerTracker.hpp"
#include "CProtMpiSplitWorld.h"

#ifndef _EXTERN_C_
#ifdef __cplusplus
#define _EXTERN_C_ extern "C"
#else /* __cplusplus */
#define _EXTERN_C_
#endif /* __cplusplus */
#endif /* _EXTERN_C_ */

#define GTI_SPLIT_MODULE_NAME "split_processes"

/* Switch between MPI-1.2 and MPI-2 errorhandler */
/* GTI_COMM_CREATE_ERRHANDLER(f, e): create a communicator errorhandler from function f and store
 * it in *e, using whichever creation call the build configuration reports as available. The call
 * is issued on the PnMPI stack held in ::stack. */
#ifdef HAVE_MPI_COMM_CREATE_ERRHANDLER
#define GTI_COMM_CREATE_ERRHANDLER(f, e) XMPI_Comm_create_errhandler_NewStack(::stack, f, e)
#else
#define GTI_COMM_CREATE_ERRHANDLER(f, e) XMPI_Errhandler_create_NewStack(::stack, f, e)
#endif

/* GTI_COMM_SET_ERRHANDLER(f, e): attach errorhandler e to communicator f, but only if f is not
 * MPI_COMM_NULL and e is non-zero.
 * NOTE: this expands to an unbraced `if` statement without a trailing semicolon. Do not use it
 * directly in front of an `else`, the `else` would bind to the macro's `if`. */
#ifdef HAVE_MPI_COMM_SET_ERRHANDLER
#define GTI_COMM_SET_ERRHANDLER(f, e)                                                              \
    if (f != MPI_COMM_NULL && e != 0)                                                              \
    XMPI_Comm_set_errhandler_NewStack(::stack, f, e)
#else
#define GTI_COMM_SET_ERRHANDLER(f, e)                                                              \
    if (f != MPI_COMM_NULL && e != 0)                                                              \
    XMPI_Errhandler_set_NewStack(::stack, f, e)
#endif

/* GTI_COMM_SET_ERR_HANDLER(c): convenience form taking a pointer to a communicator; attaches
 * gtiMpiCommErrorhandler to *c. */
#define GTI_COMM_SET_ERR_HANDLER(c) GTI_COMM_SET_ERRHANDLER(*c, gtiMpiCommErrorhandler)

/* GTI_WIN_SET_ERR_HANDLER(w): attach gtiMpiWinErrorhandler to the window *w unless it is
 * MPI_WIN_NULL. Same unbraced-`if` caveat as above. */
#define GTI_WIN_SET_ERR_HANDLER(w)                                                                 \
    if (*w != MPI_WIN_NULL)                                                                        \
    XMPI_Win_set_errhandler_NewStack(::stack, *w, gtiMpiWinErrorhandler)

/* GTI_FILE_SET_ERR_HANDLER(f): attach gtiMpiFileErrorhandler to the file *f unless it is
 * MPI_FILE_NULL. Same unbraced-`if` caveat as above. */
#define GTI_FILE_SET_ERR_HANDLER(f)                                                                \
    if (*f != MPI_FILE_NULL)                                                                       \
    XMPI_File_set_errhandler_NewStack(::stack, *f, gtiMpiFileErrorhandler)

using namespace gti;

/** Function pointer type of a signal handler as accepted by signal(). */
typedef void (*sighandler_t)(int);
/** PnMPI stack handle passed to every XMPI_*_NewStack call in this file. It starts as 0 and is
 *  looked up by name ("level_0") in gti::crashHandlingInitErrhandlers(). */
static PNMPI_modHandle_t stack{0};

/**
 * Install @p handler for every signal that the crash handling intercepts.
 *
 * The signals are SIGSEGV, SIGINT, SIGHUP, SIGABRT, SIGTERM, SIGUSR2, SIGQUIT and SIGALRM,
 * registered in exactly that order.
 *
 * @param handler the disposition to install (a handler function, SIG_DFL or SIG_IGN)
 */
static void set_signalhandlers(sighandler_t handler)
{
    static constexpr int interceptedSignals[] =
        {SIGSEGV, SIGINT, SIGHUP, SIGABRT, SIGTERM, SIGUSR2, SIGQUIT, SIGALRM};

    for (const int signalNumber : interceptedSignals) {
        signal(signalNumber, handler);
    }
}

/**
 * Reset all intercepted signals to their default disposition.
 *
 * Besides being called explicitly by the crash handlers, the function carries the
 * `destructor` attribute, so the toolchain also runs it automatically during teardown.
 */
__attribute__((destructor)) static void disable_signalhandlers()
{
    set_signalhandlers(SIG_DFL);
}

#define CALLSTACK_SIZE 20

/**
 * Write a symbolic backtrace of the calling thread to standard output.
 *
 * At most CALLSTACK_SIZE frames are captured. The symbols are written directly to the
 * file descriptor STDOUT_FILENO, i.e. they bypass the stdio buffer of `stdout`.
 */
static void print_stack(void)
{
    void* frames[CALLSTACK_SIZE + 1];
    const int frameCount = backtrace(frames, CALLSTACK_SIZE);
    backtrace_symbols_fd(frames, frameCount, STDOUT_FILENO);
}

/** The callbacks to be invoked when panicking. Appended to by gti::onPanic() and walked by
 *  notifyPanics(). */
static std::vector<gti::PanicCallbackType*> panic_callbacks{};
/** The mutex for ::panic_callbacks. Taken when registering a callback; notifyPanics() reads the
 *  vector without taking it. */
std::mutex panic_callbacks_mutex;

/** GTI's errorhandler for communicators; created in gti::crashHandlingInitErrhandlers(). */
static MPI_Errhandler gtiMpiCommErrorhandler = MPI_ERRHANDLER_NULL;
#if defined(HAS_SESSIONS_SUPPORT)
/** GTI's errorhandler for sessions; created on first use in gti::getSessionErrhandler(). */
static MPI_Errhandler gtiMpiSessionErrorhandler = MPI_ERRHANDLER_NULL;
#endif
/** GTI's errorhandler for windows; created in gti::crashHandlingInitErrhandlers(). */
static MPI_Errhandler gtiMpiWinErrorhandler = MPI_ERRHANDLER_NULL;
/** GTI's errorhandler for files; created in gti::crashHandlingInitErrhandlers(). */
static MPI_Errhandler gtiMpiFileErrorhandler = MPI_ERRHANDLER_NULL;
/** Rank and size printed in crash messages (-1 until gti::attachErrorhandler() fills them in),
 *  and the flag that enables a stack trace on SIGTERM/SIGUSR2 (set in gti::crashHandlingInit()
 *  when the environment variable INTERNAL_GTI_STACKTRACE_ON_TERM exists). */
static int gtiMpiCrashRank = -1, gtiMpiCrashSize = -1, doStacktraceOnTerm = 0;
/** Number of seconds passed to sleep() before a crashed process exits. Defaults to
 *  GTI_CRASH_SLEEP_TIME; gti::crashHandlingInit() may override it from the environment. */
static unsigned gtiCrashRankSleepTime = GTI_CRASH_SLEEP_TIME;

// File-local trackers, one per kind of MPI object. GTI's errorhandlers query them with
// getErrhandlerFunc(object) to obtain the handler function to forward an error to, and
// gti::registerPredefinedErrhandlers() adds no-op entries for the predefined errorhandlers.
namespace
{

ErrhandlerTracker<MPI_Comm> CommErrhandlerTracker{};
#if defined(HAS_SESSIONS_SUPPORT)
ErrhandlerTracker<MPI_Session> SessionErrhandlerTracker{};
#endif
ErrhandlerTracker<MPI_Win> WinErrhandlerTracker{};
ErrhandlerTracker<MPI_File> FileErrhandlerTracker{};

} // namespace

/**
 * Invoke the callbacks.
 */
static void notifyPanics()
{

    // Ideally this should be protected by a lock guard. Unfortunately are the only
    // async-signal-safe pthread functions pthread_kill, pthread_self and pthread_sigmask.
    // So let's just hope for the best!
    for (auto* const callback : panic_callbacks) {
        callback();
    }
}

/**
 * The common code for handling MPI errors.
 * It
 *  - prints a message and a stack trace.
 *  - notifies registered panic callbacks
 *  - exits the application
 *
 * @param errCode the MPI error code
 * @param userHandler callable that is invoked after sleeping
 */
static void commonMpiErrHandler(int* errCode, const std::function<void()>& userHandler)
{
    disable_signalhandlers();
    printf(
        "rank %i (of %i), pid %i caught MPI error nr %i\n",
        gtiMpiCrashRank,
        gtiMpiCrashSize,
        getpid(),
        *errCode);
    char error_string[BUFSIZ];
    int length_of_error_string;
    XMPI_Error_string_NewStack(::stack, *errCode, error_string, &length_of_error_string);
    printf("%s\n", error_string);
    print_stack();

    notifyPanics();

    printf("Waiting up to %i seconds for analyses to be finished.\n", gtiCrashRankSleepTime);
    sleep(gtiCrashRankSleepTime);

    userHandler();

    exit(1);
}

#if defined(HAS_SESSIONS_SUPPORT)
/**
 * MPI errorhandler that dispatches for errors from MPI sessions.
 *
 * Runs the common crash handling; as its last step before exiting, that code calls the handler
 * function that ::SessionErrhandlerTracker reports for @p session (if there is one).
 *
 * @param session the session associated with the error
 * @param errCode the MPI error code
 * @param ... unused; part of the MPI errorhandler signature
 */
static void myMpiSessionErrHandler(MPI_Session* session, int* errCode, ...)
{
    const auto forwardToTrackedHandler = [session, errCode]() {
        auto* const trackedHandler = SessionErrhandlerTracker.getErrhandlerFunc(*session);
        if (trackedHandler == nullptr) {
            return;
        }
        trackedHandler(session, errCode);
    };
    commonMpiErrHandler(errCode, forwardToTrackedHandler);
}
#endif

/**
 * MPI errorhandler that dispatches for errors from MPI communicators.
 *
 * Runs the common crash handling; as its last step before exiting, that code calls the handler
 * function that ::CommErrhandlerTracker reports for @p comm (if there is one).
 *
 * @param comm the communicator associated with the error
 * @param errCode the MPI error code
 * @param ... unused; part of the MPI errorhandler signature
 */
static void myMpiErrHandler(MPI_Comm* comm, int* errCode, ...)
{
    const auto forwardToTrackedHandler = [comm, errCode]() {
        auto* const trackedHandler = CommErrhandlerTracker.getErrhandlerFunc(*comm);
        if (trackedHandler == nullptr) {
            return;
        }
        trackedHandler(comm, errCode);
    };
    commonMpiErrHandler(errCode, forwardToTrackedHandler);
}

/**
 * MPI errorhandler that dispatches for errors from MPI_Win objects.
 *
 * Runs the common crash handling; as its last step before exiting, that code calls the handler
 * function that ::WinErrhandlerTracker reports for @p win (if there is one).
 *
 * @param win the window associated with the error
 * @param errCode the MPI error code
 * @param ... unused; part of the MPI errorhandler signature
 */
static void myMpiWinErrHandler(MPI_Win* win, int* errCode, ...)
{
    const auto forwardToTrackedHandler = [win, errCode]() {
        auto* const trackedHandler = WinErrhandlerTracker.getErrhandlerFunc(*win);
        if (trackedHandler == nullptr) {
            return;
        }
        trackedHandler(win, errCode);
    };
    commonMpiErrHandler(errCode, forwardToTrackedHandler);
}

/**
 * MPI errorhandler that dispatches for errors from MPI_File objects.
 *
 * Runs the common crash handling; as its last step before exiting, that code calls the handler
 * function that ::FileErrhandlerTracker reports for @p file (if there is one).
 *
 * @param file the file associated with the error
 * @param errCode the MPI error code
 * @param ... unused; part of the MPI errorhandler signature
 */
static void myMpiFileErrHandler(MPI_File* file, int* errCode, ...)
{
    const auto forwardToTrackedHandler = [file, errCode]() {
        auto* const trackedHandler = FileErrhandlerTracker.getErrhandlerFunc(*file);
        if (trackedHandler == nullptr) {
            return;
        }
        trackedHandler(file, errCode);
    };
    commonMpiErrHandler(errCode, forwardToTrackedHandler);
}

/**
 * Handler that is called on signals.
 *
 * Behaviour by signal:
 *  - SIGINT: print a stack trace, then abort MPI (or `_exit` if MPI is already finalized)
 *    with status `signum + 128`.
 *  - SIGTERM / SIGUSR2: optionally (see ::doStacktraceOnTerm) print a stack trace and wait a
 *    second, then abort/exit as above.
 *  - everything else: print a stack trace, notify the panic callbacks, sleep for
 *    ::gtiCrashRankSleepTime seconds and `_exit(1)`.
 *
 * Because the process is left via `_exit`, which does not flush stdio buffers, all buffered
 * output is flushed explicitly; otherwise the messages get lost when stdout is not a terminal.
 *
 * @param signum the signal number
 */
void mySignalHandler(int signum)
{
    disable_signalhandlers();
    printf(
        "rank %i (of %i), pid %i caught signal nr %i\n",
        gtiMpiCrashRank,
        gtiMpiCrashSize,
        static_cast<int>(getpid()),
        signum);
    // print_stack() bypasses the stdio buffer, so flush first to keep the output in order.
    fflush(stdout);

    int finalized = 0;
    MPI_Finalized(&finalized);

    // Shared tail of the SIGINT and SIGTERM/SIGUSR2 paths. If MPI_Abort should ever return,
    // control continues behind the call site exactly as in the fall-through path below.
    const auto abortOrExit = [signum, finalized]() {
        fflush(stdout);
        if (!static_cast<bool>(finalized)) {
            MPI_Abort(MPI_COMM_WORLD, signum + 128);
        } else {
            _exit(signum + 128);
        }
    };

    // SIGKILL cannot be caught; the comparison only matters if the handler is called directly.
    if (signum == SIGINT || signum == SIGKILL) {
        print_stack();
        abortOrExit();
    }
    if (signum == SIGTERM || signum == SIGUSR2) {
        if (doStacktraceOnTerm) {
            print_stack();
            fflush(stdout);
            sleep(1);
        }
        abortOrExit();
    }
    print_stack();

    notifyPanics();

    printf("Waiting up to %u seconds for analyses to be finished.\n", gtiCrashRankSleepTime);
    fflush(stdout);
    sleep(gtiCrashRankSleepTime);
    // _exit() skips the stdio cleanup, so push out anything written in the meantime.
    fflush(stdout);
    _exit(1);
}

/**
 * Install the GTI error handlers on the predefined world model communicators.
 */
void gti::crashHandlingInitWorldModel()
{
    static std::once_flag inited;
    std::call_once(inited, []() {
#ifdef HAVE_MPI_COMM_SET_ERRHANDLER
        XMPI_Comm_set_errhandler_NewStack(::stack, MPI_COMM_SELF, gtiMpiCommErrorhandler);
        XMPI_Comm_set_errhandler_NewStack(::stack, MPI_COMM_WORLD, gtiMpiCommErrorhandler);
#else
            XMPI_Errhandler_set_NewStack(::stack, MPI_COMM_SELF, gtiMpiCommErrorhandler);
            XMPI_Errhandler_set_NewStack(::stack, MPI_COMM_WORLD, gtiMpiCommErrorhandler);
#endif
    });
}

/**
 * Parse the textual value of GTI_CRASH_SLEEP_TIME.
 *
 * Accepts what strtol() accepts with base 0 (decimal, octal, hexadecimal), optionally followed
 * by blanks. Rejects empty strings, text without digits, trailing garbage, negative numbers and
 * values that do not fit into an `unsigned`.
 *
 * @param text the string to parse, must not be null
 * @param seconds receives the parsed value; left untouched on failure
 * @return true if @p text was a valid sleep time
 */
static bool parseCrashSleepTime(const char* text, unsigned& seconds)
{
    char* end = nullptr;
    errno = 0;
    const long parsed = strtol(text, &end, 0);
    if (errno != 0 || end == text) {
        return false; // out of range, or no digits consumed at all
    }
    while (*end == ' ' || *end == '\t') {
        ++end;
    }
    if (*end != '\0') {
        return false; // trailing garbage such as "30s"
    }
    if (parsed < 0 || static_cast<unsigned long>(parsed) > UINT_MAX) {
        return false; // would wrap around when stored in an unsigned
    }
    seconds = static_cast<unsigned>(parsed);
    return true;
}

/**
 * One-time initialisation of the crash handling.
 *
 * Reads the environment variables INTERNAL_GTI_STACKTRACE_ON_TERM (enables a stack trace on
 * SIGTERM/SIGUSR2 when present) and GTI_CRASH_SLEEP_TIME (seconds to wait before a crashed
 * process exits; invalid values are reported on stderr and replaced by the default), then
 * installs ::mySignalHandler for all intercepted signals.
 */
void gti::crashHandlingInit()
{
    static std::once_flag inited;
    std::call_once(inited, [] {
        if (getenv("INTERNAL_GTI_STACKTRACE_ON_TERM")) {
            doStacktraceOnTerm = 1;
        }

        const char* const gti_crash_sleep_time_str = getenv("GTI_CRASH_SLEEP_TIME");
        if (gti_crash_sleep_time_str != nullptr &&
            !parseCrashSleepTime(gti_crash_sleep_time_str, gtiCrashRankSleepTime)) {
            fprintf(
                stderr,
                "WARNING: Invalid value for environment variable GTI_CRASH_SLEEP_TIME "
                "was ignored. Using default value of %d.\n",
                GTI_CRASH_SLEEP_TIME);
            gtiCrashRankSleepTime = GTI_CRASH_SLEEP_TIME;
        }

        set_signalhandlers(mySignalHandler);
    });
}

/**
 * Attach GTI's communicator errorhandler to @p comm and record the calling process' size and
 * rank within @p comm for use in crash messages.
 *
 * @param comm the communicator to attach the errorhandler to
 */
void gti::attachErrorhandler(MPI_Comm comm)
{
    // The errorhandler has to exist already (it is created in crashHandlingInitErrhandlers()).
    assert(gtiMpiCommErrorhandler != MPI_ERRHANDLER_NULL);
    // The macro expands to a guarded call: nothing is set for MPI_COMM_NULL or a zero handler.
    GTI_COMM_SET_ERRHANDLER(comm, gtiMpiCommErrorhandler);

    // Stored in file-scope variables; the crash handlers print them as "rank X (of Y)".
    XMPI_Comm_size_NewStack(::stack, comm, &gtiMpiCrashSize);
    XMPI_Comm_rank_NewStack(::stack, comm, &gtiMpiCrashRank);
}

/**
 * Errorhandler function that does nothing at all.
 *
 * @tparam Handle the MPI object type the handler is declared for
 */
template <typename Handle>
static void noOpErrhandler(Handle* /*object*/, int* /*errCode*/, ...)
{
    // Intentionally empty.
}

/**
 * Register MPI's predefined errorhandlers with all errorhandler trackers.
 *
 * Each predefined errorhandler (MPI_ERRORS_ARE_FATAL, MPI_ERRORS_RETURN and, if available,
 * MPI_ERRORS_ABORT) is added with a no-op handler function to the communicator, session (if
 * supported), window and file trackers.
 */
auto gti::registerPredefinedErrhandlers() -> void
{
    // Register errorhandlers with no-op functions for default errorhandlers. Rationale: The app
    // might pass a predefined errorhandler.
    //
    // The array size is deduced from the initialisers on purpose: with a fixed size of 3 a
    // build without MPI_ERRORS_ABORT would contain a value-initialised (zero) handle that got
    // registered as well.
    MPI_Errhandler const predefinedErrhandlers[] = {
        MPI_ERRORS_ARE_FATAL,
        MPI_ERRORS_RETURN,
#if defined(HAVE_MPI_ERRORS_ABORT)
        MPI_ERRORS_ABORT,
#endif
    };

    for (auto const errhandler : predefinedErrhandlers) { // NOLINT(*-qualified-auto)
        CommErrhandlerTracker.add(noOpErrhandler, errhandler);
#if defined(HAS_SESSIONS_SUPPORT)
        SessionErrhandlerTracker.add(noOpErrhandler, errhandler);
#endif
        WinErrhandlerTracker.add(noOpErrhandler, errhandler);
        FileErrhandlerTracker.add(noOpErrhandler, errhandler);
    }
}

/**
 * Create GTI's MPI errorhandlers (done only once).
 *
 * Looks up the PnMPI stack named "level_0", stores it in ::stack and creates the communicator,
 * window and file errorhandlers on that stack.
 */
void gti::crashHandlingInitErrhandlers()
{
    static std::once_flag errhandlersOnce;
    std::call_once(errhandlersOnce, [] {
        const int lookupStatus = PNMPI_Service_GetStackByName("level_0", &stack);
        assert(lookupStatus == PNMPI_SUCCESS);
        static_cast<void>(lookupStatus); // only inspected by the assertion

        GTI_COMM_CREATE_ERRHANDLER(myMpiErrHandler, &gtiMpiCommErrorhandler);
        XMPI_Win_create_errhandler_NewStack(::stack, myMpiWinErrHandler, &gtiMpiWinErrorhandler);
        XMPI_File_create_errhandler_NewStack(::stack, myMpiFileErrHandler, &gtiMpiFileErrorhandler);
    });
}

#if defined(HAS_SESSIONS_SUPPORT)
/**
 * Get GTI's errorhandler for MPI sessions, creating it on the first call.
 *
 * @return the session errorhandler
 */
auto gti::getSessionErrhandler() -> MPI_Errhandler
{
    const bool alreadyCreated = (gtiMpiSessionErrorhandler != MPI_ERRHANDLER_NULL);
    if (alreadyCreated) {
        return gtiMpiSessionErrorhandler;
    }

    XMPI_Session_create_errhandler_NewStack(
        ::stack,
        myMpiSessionErrHandler,
        &gtiMpiSessionErrorhandler);
    return gtiMpiSessionErrorhandler;
}
#endif

/**
 * Get GTI's errorhandler for communicators.
 *
 * The handler must have been created already; this is checked by an assertion.
 *
 * @return the communicator errorhandler
 */
auto gti::getCommErrhandler() -> MPI_Errhandler
{
    MPI_Errhandler const commHandler = gtiMpiCommErrorhandler;
    assert(commHandler != MPI_ERRHANDLER_NULL);
    return commHandler;
}

/**
 * Print a notice on standard output if @p errh is MPI_ERRORS_RETURN.
 *
 * @param errh the errorhandler to inspect
 */
static auto warn_if_errors_return_handler(MPI_Errhandler errh) -> void
{
    if (errh != MPI_ERRORS_RETURN) {
        return;
    }

    std::cout << "GTI does not support usage of the errorhandler MPI_ERRORS_RETURN. The "
                 "errorhandler will not return.\n";
    std::cout << std::flush;
}

/**
 * Register a callback that is invoked when the application panics (MPI error or signal).
 *
 * Registration is serialised with ::panic_callbacks_mutex. A null callback is a programming
 * error: it trips an assertion in debug builds and is ignored otherwise, because the callbacks
 * are later called unconditionally from the crash path.
 *
 * @param callback the function to call on panic
 */
auto gti::onPanic(void (*callback)()) -> void
{
    assert(callback != nullptr);
    if (callback == nullptr) {
        return;
    }

    std::lock_guard<std::mutex> guard{panic_callbacks_mutex};
    panic_callbacks.push_back(callback);
}

// wrap does not generate MPI_Comm_spawn* functions
#if 0
#ifdef HAVE_MPI_COMM_SPAWN

// MPI-2
// NOTE: this wrapper sits inside an `#if 0` region and is therefore not compiled.
// It forwards all arguments to XMPI_Comm_spawn, then attaches GTI's communicator errorhandler
// to the resulting intercommunicator (skipped by the macro if that is MPI_COMM_NULL) and
// returns the result of the spawn call.
_EXTERN_C_ int MPI_Comm_spawn(
#ifndef HAVE_MPI_NO_CONST_CORRECTNESS
    const
#endif /*HAVE_MPI_NO_CONST_CORRECTNESS*/
    char* command,
    char* argv[],
    int maxprocs,
    MPI_Info info,
    int root,
    MPI_Comm comm,
    MPI_Comm* intercomm,
    int array_of_errcodes[])
{
    int ret =
        XMPI_Comm_spawn(command, argv, maxprocs, info, root, comm, intercomm, array_of_errcodes);
    GTI_COMM_SET_ERRHANDLER(*intercomm, gtiMpiCommErrorhandler);
    return ret;
}

// MPI-2
// NOTE: this wrapper sits inside an `#if 0` region and is therefore not compiled.
// It forwards all arguments to XMPI_Comm_spawn_multiple, then attaches GTI's communicator
// errorhandler to the resulting intercommunicator (skipped by the macro if that is
// MPI_COMM_NULL) and returns the result of the spawn call.
_EXTERN_C_ int MPI_Comm_spawn_multiple(
    int count,
    char* array_of_commands[],
    char** array_of_argv[],
#ifndef HAVE_MPI_NO_CONST_CORRECTNESS
    const
#endif /*HAVE_MPI_NO_CONST_CORRECTNESS*/
    int array_of_maxprocs[],
#ifndef HAVE_MPI_NO_CONST_CORRECTNESS
    const
#endif /*HAVE_MPI_NO_CONST_CORRECTNESS*/
    MPI_Info array_of_info[],
    int root,
    MPI_Comm comm,
    MPI_Comm* intercomm,
    int array_of_errcodes[])
{
    int ret = XMPI_Comm_spawn_multiple(
        count,
        array_of_commands,
        array_of_argv,
        array_of_maxprocs,
        array_of_info,
        root,
        comm,
        intercomm,
        array_of_errcodes);
    GTI_COMM_SET_ERRHANDLER(*intercomm, gtiMpiCommErrorhandler);
    return ret;
}

#endif // HAVE_MPI_COMM_SPAWN
#endif

/** Track omitted calls of MPI_xxx_get_errhandler to prevent breaking the reference count of the mpi
    library. Used in `CStratCrashHandling.w`. */
static sf::safe_ptr<std::map<MPI_Errhandler, int>> omittedGetErrhandlerCalls{};

#include "CStratCrashHandling.wrap.cpp" // NOLINT(bugprone-suspicious-include)
