/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "adapter.h"

#include <GtiMacros.h>
#include <mpi.h>
#include <pnmpi/xmpi.h>

/* Standard headers are included after mpi.h on purpose: some MPI C++
 * bindings require mpi.h to be seen before stdio.h. */
#include <cstdio>
#include <cstdlib>

using namespace gti;
using namespace must;

/* Initialize the module.
 *
 * The following macros take care of initializing the module and provide the
 * functions GTI needs for MpiADadapter. Judging by their names, they define
 * instance lookup, instance release, and the PnMPI registration point for this
 * module; see GtiMacros.h for their exact expansion. */
mGET_INSTANCE_FUNCTION(MpiADadapter);
mFREE_INSTANCE_FUNCTION(MpiADadapter);
mPNMPI_REGISTRATIONPOINT_FUNCTION(MpiADadapter);

/**
 * Map an internal callback to the related MUST event function.
 *
 * This macro takes the name of a callback and looks up the weaver-generated
 * wrapper functions `MUST_MPIADT_callback_<name>_pre` and
 * `MUST_MPIADT_callback_<name>_post`. It hands the addresses of the internal
 * function pointers `callback_<name>_pre` and `callback_<name>_post` to
 * `getWrapperFunction`, which maps the wrappers onto them.
 *
 * Both lookups are wrapped in `do { ... } while (0)`. The macro therefore
 * expands to exactly one statement and stays correct when used after an
 * unbraced `if` or before an `else`.
 *
 *
 * @param name name of the callback event
 */
#define get_must_callback(name)                                                                    \
    do {                                                                                           \
        getWrapperFunction(                                                                        \
            "MUST_MPIADT_callback_" #name "_pre",                                                  \
            reinterpret_cast<GTI_Fct_t*>(&callback_##name##_pre));                                 \
        getWrapperFunction(                                                                        \
            "MUST_MPIADT_callback_" #name "_post",                                                 \
            reinterpret_cast<GTI_Fct_t*>(&callback_##name##_post));                                \
    } while (0)

/**
 * Constructor.
 *
 * Creates the sub-module instances this adapter depends on. It then looks up
 * the weaver-generated MUST event functions for every MpiADT callback handled
 * by this adapter.
 *
 * If fewer than two sub modules are configured, an error is printed to
 * standard error and the process is aborted. Continuing would index past the
 * end of the sub-module vector, which is undefined behavior.
 *
 *
 * @param instanceName name of this module instance
 */
MpiADadapter::MpiADadapter(const char* instanceName)
    : ModuleBase<MpiADadapter, I_MpiADadapter>(instanceName)
{
    /* Get submodules.
     *
     * Pointers to the dependent modules are stored in member variables, making
     * them accessible for the lifetime of this object. The analysis
     * specification must list the parallel-ID initializer first and the
     * location-ID initializer second. */
    std::vector<I_Module*> subModInstances = createSubModuleInstances();
    if (subModInstances.size() < 2) {
        std::fprintf(
            stderr,
            "MpiADadapter: expected 2 sub modules (parallel ID and location ID initializers) "
            "but got %zu, check its analysis specification! (%s@%d)\n",
            subModInstances.size(),
            __FILE__,
            __LINE__);
        std::abort();
    }
    parallelInit = static_cast<I_InitParallelId*>(subModInstances[0]);
    locationInit = static_cast<I_InitLocationId*>(subModInstances[1]);

    /* Get the function pointers for all MUST MpiADT tool interface events.
     * These will be generated by the GTI weaver as defined in the MpiADT_api
     * specification file.
     *
     * NOTE: These callbacks will only be available if analyses are mapped on
     *       them. All callback handlers check for a valid pointer before
     *       calling, so no further availability checks are done here. */

    get_must_callback(Allgather);
    get_must_callback(Allreduce);
    get_must_callback(Alltoall);
    get_must_callback(Gather);
    get_must_callback(Isend);
    get_must_callback(Irecv);
    get_must_callback(Wait);
    get_must_callback(Waitall);
}

/* Obtain the parallel ID for the current call.
 *
 * The ID is filled in by the parallel-ID initializer sub module. See the
 * related header file for the detailed contract.
 */
MustParallelId MpiADadapter::getParallelId()
{
    MustParallelId parallelId;
    parallelInit->init(&parallelId);
    return parallelId;
}

/* Obtain the location ID for the current call.
 *
 * @p codeptr_ra (presumably the caller's return address) is handed unchanged
 * to the location-ID initializer sub module, which fills in the result. See
 * the related header file for the detailed contract.
 */
MustLocationId MpiADadapter::getLocationId(const void* codeptr_ra)
{
    MustLocationId locationId;
    locationInit->initCodePtr(&locationId, codeptr_ra);
    return locationId;
}

/* Notify the analysis about the MPI stack being shut down.
 *
 * For a detailed documentation see the related header file.
 *
 * NOTE(review): this hook is currently a no-op. A previous comment here
 * described shutting down the MpiADT runtime before the MPI runtime is shut
 * down. The rationale was that GTI sends its events via MPI messages, so
 * MpiADT shutdown events raised later could not be captured by analyses.
 * However, no such shutdown is performed. If the MpiADT runtime provides a
 * shutdown entry point, it should be called here.
 */
void MpiADadapter::finish()
{
    /* Intentionally empty: no shutdown step is implemented yet. */
}

/**
 * MPIADT `Allgather` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Allgather*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Allgather(
    const void* sendbuf,
    int sendcount,
    MPI_Datatype sendtype,
    void* recvbuf,
    int recvcount,
    MPI_Datatype recvtype,
    MPI_Comm comm,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Allgather_pre)
            a->callback_Allgather_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                sendcount,
                sendtype,
                recvbuf,
                recvcount,
                recvtype,
                comm);
    } else {
        if (a->callback_Allgather_post)
            a->callback_Allgather_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                sendcount,
                sendtype,
                recvbuf,
                recvcount,
                recvtype,
                comm);
    }
}

/**
 * MPIADT `Allreduce` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Allreduce*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Allreduce(
    const void* sendbuf,
    void* recvbuf,
    int count,
    MPI_Datatype datatype,
    MPI_Op op,
    MPI_Comm comm,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Allreduce_pre)
            a->callback_Allreduce_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                recvbuf,
                count,
                datatype,
                op,
                comm);
    } else {
        if (a->callback_Allreduce_post)
            a->callback_Allreduce_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                recvbuf,
                count,
                datatype,
                op,
                comm);
    }
}

/**
 * MPIADT `Alltoall` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Alltoall*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Alltoall(
    const void* sendbuf,
    int sendcount,
    MPI_Datatype sendtype,
    void* recvbuf,
    int recvcount,
    MPI_Datatype recvtype,
    MPI_Comm comm,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Alltoall_pre)
            a->callback_Alltoall_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                sendcount,
                sendtype,
                recvbuf,
                recvcount,
                recvtype,
                comm);
    } else {
        if (a->callback_Alltoall_post)
            a->callback_Alltoall_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                sendcount,
                sendtype,
                recvbuf,
                recvcount,
                recvtype,
                comm);
    }
}

/**
 * MPIADT `Gather` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Gather*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Gather(
    const void* sendbuf,
    int sendcnt,
    MPI_Datatype sendtype,
    void* recvbuf,
    int recvcnt,
    MPI_Datatype recvtype,
    int root,
    MPI_Comm comm,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Gather_pre)
            a->callback_Gather_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                sendcnt,
                sendtype,
                recvbuf,
                recvcnt,
                recvtype,
                root,
                comm);
    } else {
        if (a->callback_Gather_post)
            a->callback_Gather_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                sendbuf,
                sendcnt,
                sendtype,
                recvbuf,
                recvcnt,
                recvtype,
                root,
                comm);
    }
}

/**
 * MPIADT `Irecv` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Irecv*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Irecv(
    void* buf,
    int count,
    MPI_Datatype datatype,
    int source,
    int tag,
    MPI_Comm comm,
    MPI_Request* request,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Irecv_pre)
            a->callback_Irecv_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                buf,
                count,
                datatype,
                source,
                tag,
                comm,
                request);
    } else {
        if (a->callback_Irecv_post)
            a->callback_Irecv_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                buf,
                count,
                datatype,
                source,
                tag,
                comm,
                request);
    }
}

/**
 * MPIADT `Isend` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Isend*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Isend(
    const void* buf,
    int count,
    MPI_Datatype datatype,
    int dest,
    int tag,
    MPI_Comm comm,
    MPI_Request* request,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Isend_pre)
            a->callback_Isend_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                buf,
                count,
                datatype,
                dest,
                tag,
                comm,
                request);
    } else {
        if (a->callback_Isend_post)
            a->callback_Isend_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                buf,
                count,
                datatype,
                dest,
                tag,
                comm,
                request);
    }
}

/**
 * Forward the MpiADT `Wait` callback to MUST.
 *
 * At the begin endpoint the event is re-emitted as
 * `MUST_MPIADT_callback_Wait_pre`; at any other endpoint it is re-emitted as
 * `MUST_MPIADT_callback_Wait_post`. If no analysis is mapped onto the selected
 * event (its function pointer is unset), the callback is ignored.
 *
 *
 * @param request   request pointer, forwarded unchanged
 * @param status    status pointer, forwarded unchanged
 * @param endpoint  `MPIADT_endpoint_begin` selects the `_pre` event, anything
 *                  else the `_post` event
 * @param tool_data tool data slot provided by MpiADT; not used by this adapter
 * @param codeptr   code pointer converted into the MUST location ID
 */
static void on_MPIADT_callback_Wait(
    MPI_Request* request,
    MPI_Status* status,
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    (void)tool_data; /* not needed by this adapter */

    MpiADadapter* const adapter = MpiADadapter::getInstance("");

    if (endpoint != MPIADT_endpoint_begin) {
        if (!adapter->callback_Wait_post)
            return;
        adapter->callback_Wait_post(
            adapter->getParallelId(),
            adapter->getLocationId(codeptr),
            request,
            status);
        return;
    }

    if (!adapter->callback_Wait_pre)
        return;
    adapter->callback_Wait_pre(
        adapter->getParallelId(),
        adapter->getLocationId(codeptr),
        request,
        status);
}

/**
 * MPIADT `Waitall` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_MPIADT_callback_Waitall*` event. Depending on @p endpoint, the event will
 * have the postfix `pre` or  `post`.
 *
 *
 * @param endpoint       begin or end callback
 * @param codeptr
 */
static void on_MPIADT_callback_Waitall(
    int count,
    MPI_Request array_of_requests[],
    MPI_Status array_of_statuses[],
    MPIADT_endpoint_t endpoint,
    void** tool_data,
    const void* codeptr)
{
    MpiADadapter* a = MpiADadapter::getInstance("");
    if (endpoint == MPIADT_endpoint_begin) {
        if (a->callback_Waitall_pre)
            a->callback_Waitall_pre(
                a->getParallelId(),
                a->getLocationId(codeptr),
                count,
                array_of_requests,
                array_of_statuses);
    } else {
        if (a->callback_Waitall_post)
            a->callback_Waitall_post(
                a->getParallelId(),
                a->getLocationId(codeptr),
                count,
                array_of_requests,
                array_of_statuses);
    }
}

/* The following code is an adapted and commented copy of the LLVM runtime OMPT
 * test case, licensed under the Apache License. */

/**
 * Register a callback function with an explicitly given function type.
 *
 * Expands to a single statement. It stores the address of the handler
 * `on_MPIADT_callback_<name>` in a local of type @p type, so the compiler
 * checks the handler's signature against that type. The handler is then
 * passed to `register_fct` for the event `MPIADT_<name>`.
 *
 * If `register_fct` answers `MPIADT_set_never`, a message is printed to
 * standard output and execution simply continues.
 *
 * @note `register_fct` must be in scope wherever this macro is expanded.
 *
 *
 * @param name name of the callback
 * @param type type of the callback function
 */
#define register_callback_t(name, type)                                                            \
    do {                                                                                           \
        type const handler_##name = &on_MPIADT_callback_##name;                                    \
        if (MPIADT_set_never                                                                       \
            == register_fct(MPIADT_##name, reinterpret_cast<MPIADT_callback_t>(handler_##name)))   \
            printf("Could not register callback '" #name "'\n");                                   \
    } while (0)

/**
 * Register a callback function using its conventional function type.
 *
 * Shorthand for register_callback_t() that derives the function type from the
 * callback name.
 *
 * @note The handler must match the typedef `MPIADT_<name>_t`.
 *
 *
 * @param name name of the callback
 */
#define register_callback(name) register_callback_t(name, MPIADT_##name##_t)

/**
 * MpiADT tool entry point.
 *
 * Receives the MpiADT registration function and uses it to register every
 * callback handler provided by this adapter: Allgather, Allreduce, Alltoall,
 * Gather, Irecv, Isend, Wait and Waitall. It is presumably looked up and
 * called by the MpiADT runtime during its initialization; the `extern "C"`
 * linkage keeps the symbol name unmangled for such a lookup.
 *
 *
 * @param register_fct MpiADT function used to register a handler for an event.
 *                     A registration it answers with `MPIADT_set_never` is
 *                     reported on standard output but is not fatal; the
 *                     remaining callbacks are still registered.
 *
 * @return This function always returns `1`, indicating that this tool is
 *         activated.
 */
extern "C" int MPIADT_start_tool(MPIADT_register_callback_t register_fct)
{
    /* Register all callbacks defined in this adapter to be triggered by the
     * MpiADT runtime. */
    register_callback(Allgather);
    register_callback(Allreduce);
    register_callback(Alltoall);
    register_callback(Gather);
    register_callback(Irecv);
    register_callback(Isend);
    register_callback(Wait);
    register_callback(Waitall);

    /* Return 1 to indicate that this tool is activated and its callbacks should
     * be called. */
    return 1;
}
