/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "adapter.h"

#include "BaseIds.h"
#include "PrefixedOstream.hpp"
#include "omp-tools.h"
#include <GtiEnums.h>
#include <GtiMacros.h>
#include <I_Module.h>
#include <atomic>
#include <vector>

#define OMPT_MULTIPLEX_TOOL_NAME "MUST"
#include "ompt-multiplex.h"

using namespace gti;
using namespace must;

/* Initialize the module.
 *
 * The following macros take care of initializing the module and provide the
 * functions GTI needs for it (instance retrieval, instance release and the
 * PnMPI registration point). */
mGET_INSTANCE_FUNCTION(OpenMPadapter);
mFREE_INSTANCE_FUNCTION(OpenMPadapter);
mPNMPI_REGISTRATIONPOINT_FUNCTION(OpenMPadapter);

namespace must
{
namespace
{
/**
 * @brief Flag holder telling whether an OpenMPadapter has been constructed yet.
 *
 * OMPT may invoke a registered callback before any instance of OpenMPadapter exists. Instances
 * should normally only be created through the GTI provided macros, so a callback must never be
 * the first caller of OpenMPadapter::getInstance. #is_some_instance_created allows checking
 * for that situation.
 */
class OpenMPadapterInitGuard
{
  public:
    /**
     * @return True iff an instance of OpenMPadapter has been created.
     */
    static bool is_some_instance_created() noexcept
    {
        return instance_created.load(std::memory_order_acquire);
    }

  private:
    /**
     * Grants the OpenMPadapter constructor access to #notify.
     */
    friend OpenMPadapter::OpenMPadapter(const char* instanceName);

    /**
     * Starts out as `false` and is set to `true` by #notify.
     */
    static std::atomic<bool> instance_created;

    /**
     * Record that an instance of OpenMPadapter was constructed.
     *
     * The release store pairs with the acquire load in #is_some_instance_created.
     */
    static void notify() noexcept { instance_created.store(true, std::memory_order_release); }
};

std::atomic<bool> OpenMPadapterInitGuard::instance_created{false};
} // namespace
} // namespace must

/**
 * OMPT `set_callback` function.
 *
 * This function will be used to register callbacks in the OpenMP runtime.
 *
 * @note The OpenMP tools interface does not provide external function symbols
 *       for its service functions. Therefore, this variable will be used to
 *       store a function pointer returned by the OpenMP lookup function. It is
 *       `nullptr` until such a pointer has been stored.
 */
static ompt_set_callback_t ompt_set_callback = nullptr;

/**
 * OMPT `get_unique_id` function.
 *
 * This function will be used to generate a unique ID in the begin event
 * callbacks to match beginning and ending events.
 *
 * @note The OpenMP tools interface does not provide external function symbols
 *       for its service functions. Therefore, this variable will be used to
 *       store a function pointer returned by the OpenMP lookup function. It is
 *       `nullptr` until such a pointer has been stored.
 */
static ompt_get_unique_id_t ompt_get_unique_id = nullptr;

/**
 * OMPT `ompt_finalize_tool` function.
 *
 * This function will be used to trigger the shutdown of the OpenMP runtime.
 *
 * @note The OpenMP tools interface does not provide external function symbols
 *       for its service functions. Therefore, this variable will be used to
 *       store a function pointer returned by the OpenMP lookup function. It is
 *       `nullptr` until such a pointer has been stored.
 */
static ompt_finalize_tool_t ompt_finalize_tool = nullptr;

/**
 * Map an internal callback to the related MUST event function.
 *
 * This macro takes the name of a callback and maps the weaver-generated wrapper
 * function to the internal callback function pointer.
 *
 * It expands to a call of `getWrapperFunction` with the wrapper name
 * `"MUST_ompt_callback_<name>"` and the address of the member
 * `callback_<name>` (cast to `GTI_Fct_t*`) as destination. It is therefore only
 * usable where both `getWrapperFunction` and `callback_<name>` are in scope.
 *
 *
 * @param name name of the callback event
 */
#define get_must_callback(name)                                                                    \
    getWrapperFunction("MUST_ompt_callback_" #name, (GTI_Fct_t*)&callback_##name)

/**
 * Constructor.
 *
 * Fetches the sub modules, resolves the wrapper function of every MUST OpenMP
 * event into the matching `callback_*` member and finally notifies the
 * OpenMPadapterInitGuard that an instance exists.
 *
 *
 * @param instanceName name of this module instance
 */
OpenMPadapter::OpenMPadapter(const char* instanceName)
    : ModuleBase<OpenMPadapter, I_OpenMPadapter>(instanceName)
{
    /* Get submodules.
     *
     * A pointer to the dependent modules will be stored in member variables,
     * making them accessible for the lifetime of this object.
     *
     * NOTE(review): the indices are used without a size check; this assumes
     *               that at least two sub modules are returned, in the order
     *               parallel ID initializer, location ID initializer - confirm
     *               against the module specification. */
    std::vector<I_Module*> subModInstances = createSubModuleInstances();
    parallelInit = (I_InitParallelId*)subModInstances[0];
    locationInit = (I_InitLocationId*)subModInstances[1];

    /* Get the function pointers for all MUST OpenMP tool interface events.
     * These will be generated by the GTI weaver as defined in the OpenMP_api
     * specification file.
     *
     * NOTE: These callbacks will be available, if analyses are mapped on them.
     *       As all callback functions check for a valid pointer, there will be
     *       no further checks, whether the callback is available. */

    /* Mutex events: waiting for a mutex. */
    get_must_callback(wait_lock);
    get_must_callback(wait_nest_lock);
    get_must_callback(wait_critical);
    get_must_callback(wait_atomic);
    get_must_callback(wait_ordered);

    /* Mutex events: mutex acquired. */
    get_must_callback(acquired_lock);
    get_must_callback(acquired_nest_lock);
    get_must_callback(acquired_critical);
    get_must_callback(acquired_atomic);
    get_must_callback(acquired_ordered);

    /* Mutex events: mutex released. */
    get_must_callback(release_lock);
    get_must_callback(release_nest_lock);
    get_must_callback(release_critical);
    get_must_callback(release_atomic);
    get_must_callback(release_ordered);

    /* Nest lock events. */
    get_must_callback(acquired_nest_lock_next);
    get_must_callback(release_nest_lock_prev);

    /* Synchronization region events. */
    get_must_callback(barrier_begin);
    get_must_callback(barrier_end);
    get_must_callback(taskwait_begin);
    get_must_callback(taskwait_end);
    get_must_callback(taskgroup_begin);
    get_must_callback(taskgroup_end);

    /* Synchronization region wait events. */
    get_must_callback(wait_barrier_begin);
    get_must_callback(wait_barrier_end);
    get_must_callback(wait_taskwait_begin);
    get_must_callback(wait_taskwait_end);
    get_must_callback(wait_taskgroup_begin);
    get_must_callback(wait_taskgroup_end);

    get_must_callback(flush);

    get_must_callback(cancel);

    /* Initial and implicit task events. */
    get_must_callback(initial_task_begin);
    get_must_callback(initial_task_end);
    get_must_callback(implicit_task_begin);
    get_must_callback(implicit_task_end);

    /* Lock initialization events. */
    get_must_callback(init_lock);
    get_must_callback(init_nest_lock);

    /* Lock destruction events. */
    get_must_callback(destroy_lock);
    get_must_callback(destroy_nest_lock);

    /* Work events. */
    get_must_callback(loop_begin);
    get_must_callback(loop_end);
    get_must_callback(sections_begin);
    get_must_callback(sections_end);
    get_must_callback(single_in_block_begin);
    get_must_callback(single_in_block_end);
    get_must_callback(single_others_begin);
    get_must_callback(single_others_end);
    get_must_callback(distribute_begin);
    get_must_callback(distribute_end);
    get_must_callback(taskloop_begin);
    get_must_callback(taskloop_end);

    /* Masked region events. */
    get_must_callback(masked_begin);
    get_must_callback(masked_end);

    /* Parallel region events. */
    get_must_callback(parallel_begin);
    get_must_callback(parallel_end);

    /* Task events. */
    get_must_callback(task_create);
    get_must_callback(task_schedule);
    get_must_callback(task_end);
    get_must_callback(task_dependences);
    get_must_callback(task_dependence_pair);

    /* Thread events. */
    get_must_callback(thread_begin);
    get_must_callback(thread_end);

    get_must_callback(control_tool);

    /* Tool initialization and finalization events. */
    get_must_callback(initialize);
    get_must_callback(finalize);

    /* Mark that an instance exists. This is the last statement, presumably so
     * that the flag only becomes visible after all members above are set up. */
    OpenMPadapterInitGuard::notify();
}

/* Get the parallel ID for a specific call.
 *
 * The ID is filled in by the parallel ID initializer sub module. It is
 * value-initialized beforehand, so that a defined value is returned even if
 * the sub module should not write to it.
 *
 * For a detailed documentation see the related header file.
 */
MustParallelId OpenMPadapter::getParallelId()
{
    MustParallelId id{};
    this->parallelInit->init(&id);
    return id;
}

/* Get the location ID for a specific call.
 *
 * The ID is filled in by the location ID initializer sub module from the
 * given code pointer. It is value-initialized beforehand, so that a defined
 * value is returned even if the sub module should not write to it.
 *
 * For a detailed documentation see the related header file.
 */
MustLocationId OpenMPadapter::getLocationId(const void* codeptr_ra)
{
    MustLocationId id{};
    this->locationInit->initCodePtr(&id, codeptr_ra);
    return id;
}

/* Notify the analysis about the MPI stack being shut down.
 *
 * Invokes the stored OMPT `finalize_tool` function, if there is one.
 *
 * For a detailed documentation see the related header file.
 */
void OpenMPadapter::finish()
{
    /* Without a stored finalize function there is nothing to trigger. */
    if (ompt_finalize_tool == nullptr) {
        return;
    }

    /* The OpenMP runtime has to be shut down before the MPI runtime: GTI
     * sends its events via MPI messages, so OpenMP shutdown events could not
     * be captured by analyses otherwise. */
    ompt_finalize_tool();
}

/* Destructor.
 *
 * Passes each sub module pointer that is still set to
 * destroySubModuleInstance and resets the members afterwards.
 */
OpenMPadapter::~OpenMPadapter()
{
    if (parallelInit) {
        destroySubModuleInstance(static_cast<I_Module*>(parallelInit));
    }
    parallelInit = nullptr;

    if (locationInit) {
        destroySubModuleInstance(static_cast<I_Module*>(locationInit));
    }
    locationInit = nullptr;
}

/**
 * Read the value stored in an ompt data structure.
 *
 * A null pointer is tolerated, so callers may pass through whatever pointer
 * they were given.
 *
 *
 * @param data the data structure to read from (may be `NULL`)
 *
 * @return The `value` member of @p data, or 0 if @p data is `NULL`.
 */
static uint64_t getOmptData(ompt_data_t* data)
{
    if (data == nullptr) {
        return 0;
    }
    return data->value;
}

/**
 * Redirect a `ompt_callback_mutex_acquire_t` callback to its MUST event.
 *
 * The macro does nothing, if `a` is `nullptr` or `a->callback_<name>` is not
 * set. It captures the following identifiers from the expanding scope: `a`
 * (pointer to the adapter instance), `hint`, `impl`, `wait_id` and
 * `codeptr_ra`.
 *
 *
 * @param name name of the callback event
 */
#define redirect_callback_mutex_acquire(name)                                                      \
    do {                                                                                           \
        if (a != nullptr && a->callback_##name)                                                    \
            a->callback_##name(                                                                    \
                a->getParallelId(),                                                                \
                a->getLocationId(codeptr_ra),                                                      \
                hint,                                                                              \
                impl,                                                                              \
                wait_id,                                                                           \
                reinterpret_cast<MustAddressType>(codeptr_ra));                                    \
    } while (0)

/**
 * Redirect a `ompt_callback_mutex_t` callback to its MUST event.
 *
 * The macro does nothing, if `a` is `nullptr` or `a->callback_<name>` is not
 * set. It captures the following identifiers from the expanding scope: `a`
 * (pointer to the adapter instance), `wait_id` and `codeptr_ra`.
 *
 *
 * @param name name of the callback event
 */
#define redirect_callback_mutex(name)                                                              \
    do {                                                                                           \
        if (a != nullptr && a->callback_##name)                                                    \
            a->callback_##name(                                                                    \
                a->getParallelId(),                                                                \
                a->getLocationId(codeptr_ra),                                                      \
                wait_id,                                                                           \
                reinterpret_cast<MustAddressType>(codeptr_ra));                                    \
    } while (0)

/**
 * Redirect a `ompt_callback_sync_region_t` callback to its MUST event.
 *
 * The macro does nothing, if `a` is `nullptr` or `a->callback_<name>` is not
 * set. It captures the following identifiers from the expanding scope: `a`
 * (pointer to the adapter instance), `kind`, `parallel_data`, `task_data` and
 * `codeptr_ra`. The data pointers are converted with getOmptData and may
 * therefore be `NULL`.
 *
 *
 * @param name name of the callback event
 */
#define redirect_callback_sync_region(name)                                                        \
    do {                                                                                           \
        if (a != nullptr && a->callback_##name)                                                    \
            a->callback_##name(                                                                    \
                a->getParallelId(),                                                                \
                a->getLocationId(codeptr_ra),                                                      \
                kind,                                                                              \
                getOmptData(parallel_data),                                                        \
                getOmptData(task_data),                                                            \
                reinterpret_cast<MustAddressType>(codeptr_ra));                                    \
    } while (0)

/**
 * Redirect a `ompt_callback_implicit_task_t` callback to its MUST event.
 *
 * The macro does nothing, if `a` is `nullptr` or `a->callback_<name>` is not
 * set. It captures the following identifiers from the expanding scope: `a`
 * (pointer to the adapter instance), `parallel_data`, `task_data`,
 * `team_size`, `thread_num` and `flags`. The data pointers are converted with
 * getOmptData and may therefore be `NULL`. Unlike the other redirect macros,
 * no location ID and no code pointer are passed.
 *
 *
 * @param name name of the callback event
 */
#define redirect_callback_implicit_task(name)                                                      \
    do {                                                                                           \
        if (a != nullptr && a->callback_##name)                                                    \
            a->callback_##name(                                                                    \
                a->getParallelId(),                                                                \
                getOmptData(parallel_data),                                                        \
                getOmptData(task_data),                                                            \
                team_size,                                                                         \
                thread_num,                                                                        \
                flags);                                                                            \
    } while (0)

/**
 * Redirect a `ompt_callback_masked_t` callback to its MUST event.
 *
 * The macro does nothing, if `a` is `nullptr` or `a->callback_<name>` is not
 * set. It captures the following identifiers from the expanding scope: `a`
 * (pointer to the adapter instance), `parallel_data`, `task_data` and
 * `codeptr_ra`. The data pointers are converted with getOmptData and may
 * therefore be `NULL`.
 *
 *
 * @param name name of the callback event
 */
#define redirect_callback_masked(name)                                                             \
    do {                                                                                           \
        if (a != nullptr && a->callback_##name)                                                    \
            a->callback_##name(                                                                    \
                a->getParallelId(),                                                                \
                a->getLocationId(codeptr_ra),                                                      \
                getOmptData(parallel_data),                                                        \
                getOmptData(task_data),                                                            \
                reinterpret_cast<MustAddressType>(codeptr_ra));                                    \
    } while (0)

/**
 * Redirect a `ompt_callback_work_t` callback to its MUST event.
 *
 * The macro does nothing, if `a` is `nullptr` or `a->callback_<name>` is not
 * set. It captures the following identifiers from the expanding scope: `a`
 * (pointer to the adapter instance), `parallel_data`, `task_data`, `count` and
 * `codeptr_ra`. The data pointers are converted with getOmptData and may
 * therefore be `NULL`.
 *
 *
 * @param name name of the callback event
 */
#define redirect_callback_work(name)                                                               \
    do {                                                                                           \
        if (a != nullptr && a->callback_##name)                                                    \
            a->callback_##name(                                                                    \
                a->getParallelId(),                                                                \
                a->getLocationId(codeptr_ra),                                                      \
                getOmptData(parallel_data),                                                        \
                getOmptData(task_data),                                                            \
                count,                                                                             \
                reinterpret_cast<MustAddressType>(codeptr_ra));                                    \
    } while (0)

/**
 * OMPT `mutex_acquire` callback handler.
 *
 * Forwards the event to MUST as one of the `MUST_ompt_callback_wait_*` events.
 * The postfix (`lock`, `nest_lock`, `critical`, `atomic` or `ordered`) is
 * selected by @p kind; any other kind is ignored.
 *
 *
 * @param kind       the kind of mutex
 * @param hint       forwarded unchanged to the MUST event
 * @param impl       forwarded unchanged to the MUST event
 * @param wait_id    forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void on_ompt_callback_mutex_acquire(
    ompt_mutex_t kind,
    unsigned int hint,
    unsigned int impl,
    ompt_wait_id_t wait_id,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr) {
        /* Without an adapter none of the redirects below would fire. */
        return;
    }

    if (kind == ompt_mutex_lock) {
        redirect_callback_mutex_acquire(wait_lock);
    } else if (kind == ompt_mutex_nest_lock) {
        redirect_callback_mutex_acquire(wait_nest_lock);
    } else if (kind == ompt_mutex_critical) {
        redirect_callback_mutex_acquire(wait_critical);
    } else if (kind == ompt_mutex_atomic) {
        redirect_callback_mutex_acquire(wait_atomic);
    } else if (kind == ompt_mutex_ordered) {
        redirect_callback_mutex_acquire(wait_ordered);
    }
}

/**
 * OMPT `mutex_acquired` callback handler.
 *
 * Forwards the event to MUST as one of the `MUST_ompt_callback_acquired_*`
 * events. The postfix (`lock`, `nest_lock`, `critical`, `atomic` or `ordered`)
 * is selected by @p kind; any other kind is ignored.
 *
 *
 * @param kind       the kind of mutex
 * @param wait_id    forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void
on_ompt_callback_mutex_acquired(ompt_mutex_t kind, ompt_wait_id_t wait_id, const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr) {
        /* Without an adapter none of the redirects below would fire. */
        return;
    }

    if (kind == ompt_mutex_lock) {
        redirect_callback_mutex(acquired_lock);
    } else if (kind == ompt_mutex_nest_lock) {
        redirect_callback_mutex(acquired_nest_lock);
    } else if (kind == ompt_mutex_critical) {
        redirect_callback_mutex(acquired_critical);
    } else if (kind == ompt_mutex_atomic) {
        redirect_callback_mutex(acquired_atomic);
    } else if (kind == ompt_mutex_ordered) {
        redirect_callback_mutex(acquired_ordered);
    }
}

/**
 * OMPT `mutex_released` callback handler.
 *
 * Forwards the event to MUST as one of the `MUST_ompt_callback_release_*`
 * events. The postfix (`lock`, `nest_lock`, `critical`, `atomic` or `ordered`)
 * is selected by @p kind; any other kind is ignored.
 *
 *
 * @param kind       the kind of mutex
 * @param wait_id    forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void
on_ompt_callback_mutex_released(ompt_mutex_t kind, ompt_wait_id_t wait_id, const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr) {
        /* Without an adapter none of the redirects below would fire. */
        return;
    }

    if (kind == ompt_mutex_lock) {
        redirect_callback_mutex(release_lock);
    } else if (kind == ompt_mutex_nest_lock) {
        redirect_callback_mutex(release_nest_lock);
    } else if (kind == ompt_mutex_critical) {
        redirect_callback_mutex(release_critical);
    } else if (kind == ompt_mutex_atomic) {
        redirect_callback_mutex(release_atomic);
    } else if (kind == ompt_mutex_ordered) {
        redirect_callback_mutex(release_ordered);
    }
}

/**
 * OMPT `nest_lock` callback handler.
 *
 * Forwards the event to MUST. A `begin` endpoint raises
 * `MUST_ompt_callback_acquired_nest_lock_next`, an `end` endpoint raises
 * `MUST_ompt_callback_release_nest_lock_prev` and a `beginend` endpoint raises
 * both, in that order.
 *
 *
 * @param endpoint   whether it is the begin or end event (or both)
 * @param wait_id    forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void on_ompt_callback_nest_lock(
    ompt_scope_endpoint_t endpoint,
    ompt_wait_id_t wait_id,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");

    /* `beginend` counts as both, entering and leaving. */
    const bool entering = (endpoint == ompt_scope_begin) || (endpoint == ompt_scope_beginend);
    const bool leaving = (endpoint == ompt_scope_end) || (endpoint == ompt_scope_beginend);

    if (entering) {
        redirect_callback_mutex(acquired_nest_lock_next);
    }
    if (leaving) {
        redirect_callback_mutex(release_nest_lock_prev);
    }
}

/**
 * OMPT `sync_region` callback handler.
 *
 * Forwards the event to MUST as `MUST_ompt_callback_<x>_begin` and/or
 * `MUST_ompt_callback_<x>_end`. `x` is `barrier` (for all barrier kinds),
 * `taskwait` or `taskgroup`, depending on @p kind. Reduction regions are
 * ignored. A `beginend` endpoint raises the begin event followed by the end
 * event.
 *
 *
 * @param kind          the kind of synchronization region
 * @param endpoint      whether it is the begin or end event (or both)
 * @param parallel_data unique ID of the parallel region
 * @param task_data     unique ID of the task
 * @param codeptr_ra    code pointer used for the location ID and forwarded as
 *                      address
 */
static void on_ompt_callback_sync_region(
    ompt_sync_region_t kind,
    ompt_scope_endpoint_t endpoint,
    ompt_data_t* parallel_data,
    ompt_data_t* task_data,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");

    /* `beginend` counts as both, entering and leaving. */
    const bool entering = (endpoint == ompt_scope_begin) || (endpoint == ompt_scope_beginend);
    const bool leaving = (endpoint == ompt_scope_end) || (endpoint == ompt_scope_beginend);

    switch (kind) {
    case ompt_sync_region_barrier:
    case ompt_sync_region_barrier_implicit:
    case ompt_sync_region_barrier_implicit_workshare:
    case ompt_sync_region_barrier_implicit_parallel:
    case ompt_sync_region_barrier_teams:
    case ompt_sync_region_barrier_explicit:
    case ompt_sync_region_barrier_implementation:
        if (entering) {
            redirect_callback_sync_region(barrier_begin);
        }
        if (leaving) {
            redirect_callback_sync_region(barrier_end);
        }
        break;
    case ompt_sync_region_taskwait:
        if (entering) {
            redirect_callback_sync_region(taskwait_begin);
        }
        if (leaving) {
            redirect_callback_sync_region(taskwait_end);
        }
        break;
    case ompt_sync_region_taskgroup:
        if (entering) {
            redirect_callback_sync_region(taskgroup_begin);
        }
        if (leaving) {
            redirect_callback_sync_region(taskgroup_end);
        }
        break;
    case ompt_sync_region_reduction:
        /* No MUST event is mapped to reductions. */
        break;
    }
}

/**
 * OMPT `sync_region_wait` callback handler.
 *
 * Forwards the event to MUST as `MUST_ompt_callback_wait_<x>_begin` and/or
 * `MUST_ompt_callback_wait_<x>_end`. `x` is `barrier` (for all barrier kinds),
 * `taskwait` or `taskgroup`, depending on @p kind. Reduction regions are
 * ignored. A `beginend` endpoint raises the begin event followed by the end
 * event.
 *
 *
 * @param kind          the kind of synchronization region
 * @param endpoint      whether it is the begin or end event (or both)
 * @param parallel_data unique ID of the parallel region
 * @param task_data     unique ID of the task
 * @param codeptr_ra    code pointer used for the location ID and forwarded as
 *                      address
 */
static void on_ompt_callback_sync_region_wait(
    ompt_sync_region_t kind,
    ompt_scope_endpoint_t endpoint,
    ompt_data_t* parallel_data,
    ompt_data_t* task_data,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");

    /* `beginend` counts as both, entering and leaving. */
    const bool entering = (endpoint == ompt_scope_begin) || (endpoint == ompt_scope_beginend);
    const bool leaving = (endpoint == ompt_scope_end) || (endpoint == ompt_scope_beginend);

    switch (kind) {
    case ompt_sync_region_barrier:
    case ompt_sync_region_barrier_implicit:
    case ompt_sync_region_barrier_implicit_workshare:
    case ompt_sync_region_barrier_implicit_parallel:
    case ompt_sync_region_barrier_teams:
    case ompt_sync_region_barrier_explicit:
    case ompt_sync_region_barrier_implementation:
        if (entering) {
            redirect_callback_sync_region(wait_barrier_begin);
        }
        if (leaving) {
            redirect_callback_sync_region(wait_barrier_end);
        }
        break;
    case ompt_sync_region_taskwait:
        if (entering) {
            redirect_callback_sync_region(wait_taskwait_begin);
        }
        if (leaving) {
            redirect_callback_sync_region(wait_taskwait_end);
        }
        break;
    case ompt_sync_region_taskgroup:
        if (entering) {
            redirect_callback_sync_region(wait_taskgroup_begin);
        }
        if (leaving) {
            redirect_callback_sync_region(wait_taskgroup_end);
        }
        break;
    case ompt_sync_region_reduction:
        /* No MUST event is mapped to reductions. */
        break;
    }
}

/**
 * OMPT `flush` callback handler.
 *
 * Passes the event through to MUST as `MUST_ompt_callback_flush` event. The
 * event is dropped, if there is no adapter instance or no wrapper function
 * for it.
 *
 *
 * @param thread_data unique ID of the thread
 * @param codeptr_ra  code pointer used for the location ID and forwarded as
 *                    address
 */
static void on_ompt_callback_flush(ompt_data_t* thread_data, const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr || !a->callback_flush) {
        return;
    }

    a->callback_flush(
        a->getParallelId(),
        a->getLocationId(codeptr_ra),
        getOmptData(thread_data),
        reinterpret_cast<MustAddressType>(codeptr_ra));
}

/**
 * OMPT `cancel` callback handler.
 *
 * Passes the event through to MUST as `MUST_ompt_callback_cancel` event. The
 * event is dropped, if there is no adapter instance or no wrapper function
 * for it.
 *
 *
 * @param task_data  unique ID of the task
 * @param flags      forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void on_ompt_callback_cancel(ompt_data_t* task_data, int flags, const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr || !a->callback_cancel) {
        return;
    }

    a->callback_cancel(
        a->getParallelId(),
        a->getLocationId(codeptr_ra),
        getOmptData(task_data),
        flags,
        reinterpret_cast<MustAddressType>(codeptr_ra));
}

/**
 * OMPT `implicit_task` callback handler.
 *
 * This callback handler will redirect the event to MUST as
 * `MUST_ompt_callback_<x>_task_*` event. Depending on @p flags, `x` will be
 * `initial` (if `ompt_task_initial` is set) or `implicit`. The postfix will be
 * `begin` or `end` depending on @p endpoint; a `beginend` endpoint raises the
 * begin event followed by the end event.
 *
 * On the begin event a fresh unique ID is stored in @p task_data and, for
 * initial tasks, in @p parallel_data. Both pointers are checked for `NULL`
 * before being written to, matching how getOmptData treats `NULL` pointers
 * when the IDs are read.
 *
 *
 * @param endpoint      whether it is the begin or end event (or both)
 * @param parallel_data unique ID of the parallel region (set by this callback
 *                      for initial tasks)
 * @param task_data     unique ID of the task (set by this callback)
 * @param team_size     forwarded unchanged to the MUST event
 * @param thread_num    forwarded unchanged to the MUST event
 * @param flags         selects between initial and implicit task events and is
 *                      forwarded to the MUST event
 */
static void on_ompt_callback_implicit_task(
    ompt_scope_endpoint_t endpoint,
    ompt_data_t* parallel_data,
    ompt_data_t* task_data,
    unsigned int team_size,
    unsigned int thread_num,
    int flags)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");

    /* `beginend` counts as both, entering and leaving. */
    const bool entering = (endpoint == ompt_scope_begin) || (endpoint == ompt_scope_beginend);
    const bool leaving = (endpoint == ompt_scope_end) || (endpoint == ompt_scope_beginend);
    const bool is_initial_task = (flags & ompt_task_initial) != 0;

    if (entering) {
        /* The IDs are assigned even without an adapter instance, so that
         * later events can still be matched to this one. */
        if (task_data)
            task_data->value = ompt_get_unique_id();

        if (is_initial_task) {
            if (parallel_data)
                parallel_data->value = ompt_get_unique_id();
            redirect_callback_implicit_task(initial_task_begin);
        } else {
            redirect_callback_implicit_task(implicit_task_begin);
        }
    }

    if (leaving) {
        if (is_initial_task) {
            redirect_callback_implicit_task(initial_task_end);
        } else {
            redirect_callback_implicit_task(implicit_task_end);
        }
    }
}

/**
 * OMPT `lock_init` callback handler.
 *
 * Forwards the event to MUST as `MUST_ompt_callback_init_lock` or
 * `MUST_ompt_callback_init_nest_lock`, depending on @p kind. All other kinds
 * are ignored.
 *
 *
 * @param kind       the kind of mutex
 * @param hint       forwarded unchanged to the MUST event
 * @param impl       forwarded unchanged to the MUST event
 * @param wait_id    forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void on_ompt_callback_lock_init(
    ompt_mutex_t kind,
    unsigned int hint,
    unsigned int impl,
    ompt_wait_id_t wait_id,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr) {
        /* Without an adapter none of the redirects below would fire. */
        return;
    }

    if (kind == ompt_mutex_lock) {
        redirect_callback_mutex_acquire(init_lock);
    } else if (kind == ompt_mutex_nest_lock) {
        redirect_callback_mutex_acquire(init_nest_lock);
    }
}

/**
 * OMPT `lock_destroy` callback handler.
 *
 * Forwards the event to MUST as `MUST_ompt_callback_destroy_lock` or
 * `MUST_ompt_callback_destroy_nest_lock`, depending on @p kind. All other
 * kinds are ignored.
 *
 *
 * @param kind       the kind of mutex
 * @param wait_id    forwarded unchanged to the MUST event
 * @param codeptr_ra code pointer used for the location ID and forwarded as
 *                   address
 */
static void
on_ompt_callback_lock_destroy(ompt_mutex_t kind, ompt_wait_id_t wait_id, const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a == nullptr) {
        /* Without an adapter none of the redirects below would fire. */
        return;
    }

    if (kind == ompt_mutex_lock) {
        redirect_callback_mutex(destroy_lock);
    } else if (kind == ompt_mutex_nest_lock) {
        redirect_callback_mutex(destroy_nest_lock);
    }
}

/**
 * OMPT `work` callback handler.
 *
 * Forwards the event to MUST as `MUST_ompt_callback_<x>_begin` and/or
 * `MUST_ompt_callback_<x>_end`. Depending on @p wstype, `x` is `loop` (for all
 * loop kinds), `sections`, `single_in_block`, `single_others`, `distribute` or
 * `taskloop`. Workshare and scope constructs are ignored. A `beginend`
 * endpoint raises the begin event followed by the end event.
 *
 *
 * @param wstype        the kind of workload
 * @param endpoint      whether it is the begin or end event (or both)
 * @param parallel_data unique ID of the parallel region
 * @param task_data     unique ID of the task
 * @param count         forwarded unchanged to the MUST event
 * @param codeptr_ra    code pointer used for the location ID and forwarded as
 *                      address
 */
static void on_ompt_callback_work(
    ompt_work_t wstype,
    ompt_scope_endpoint_t endpoint,
    ompt_data_t* parallel_data,
    ompt_data_t* task_data,
    uint64_t count,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");

    /* `beginend` counts as both, entering and leaving. */
    const bool entering = (endpoint == ompt_scope_begin) || (endpoint == ompt_scope_beginend);
    const bool leaving = (endpoint == ompt_scope_end) || (endpoint == ompt_scope_beginend);

    switch (wstype) {
    case ompt_work_loop:
    case ompt_work_loop_static:
    case ompt_work_loop_dynamic:
    case ompt_work_loop_guided:
    case ompt_work_loop_other:
        if (entering) {
            redirect_callback_work(loop_begin);
        }
        if (leaving) {
            redirect_callback_work(loop_end);
        }
        break;
    case ompt_work_sections:
        if (entering) {
            redirect_callback_work(sections_begin);
        }
        if (leaving) {
            redirect_callback_work(sections_end);
        }
        break;
    case ompt_work_single_executor:
        if (entering) {
            redirect_callback_work(single_in_block_begin);
        }
        if (leaving) {
            redirect_callback_work(single_in_block_end);
        }
        break;
    case ompt_work_single_other:
        if (entering) {
            redirect_callback_work(single_others_begin);
        }
        if (leaving) {
            redirect_callback_work(single_others_end);
        }
        break;
    case ompt_work_workshare:
    case ompt_work_scope:
        /* No MUST event is mapped to these constructs. */
        break;
    case ompt_work_distribute:
        if (entering) {
            redirect_callback_work(distribute_begin);
        }
        if (leaving) {
            redirect_callback_work(distribute_end);
        }
        break;
    case ompt_work_taskloop:
        if (entering) {
            redirect_callback_work(taskloop_begin);
        }
        if (leaving) {
            redirect_callback_work(taskloop_end);
        }
        break;
    }
}

/**
 * OMPT `masked` callback handler.
 *
 * Forwards the event to MUST as `MUST_ompt_callback_masked_begin` and/or
 * `MUST_ompt_callback_masked_end`, depending on @p endpoint. A `beginend`
 * endpoint raises the begin event followed by the end event.
 *
 *
 * @param endpoint      whether it is the begin or end event (or both)
 * @param parallel_data unique ID of the parallel region
 * @param task_data     unique ID of the task
 * @param codeptr_ra    code pointer used for the location ID and forwarded as
 *                      address
 */
static void on_ompt_callback_masked(
    ompt_scope_endpoint_t endpoint,
    ompt_data_t* parallel_data,
    ompt_data_t* task_data,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");

    /* `beginend` counts as both, entering and leaving. */
    const bool entering = (endpoint == ompt_scope_begin) || (endpoint == ompt_scope_beginend);
    const bool leaving = (endpoint == ompt_scope_end) || (endpoint == ompt_scope_beginend);

    if (entering) {
        redirect_callback_masked(masked_begin);
    }
    if (leaving) {
        redirect_callback_masked(masked_end);
    }
}

/**
 * OMPT `parallel_begin` callback handler.
 *
 * Assigns a fresh unique ID to the new parallel region and then passes the
 * event through to MUST as `MUST_ompt_callback_parallel_begin` event, provided
 * that an adapter instance exists and the event callback is set.
 *
 *
 * @param encountering_task_data  data of the task encountering the region
 * @param encountering_task_frame not forwarded to MUST
 * @param parallel_data           unique ID of the parallel region (set by this
 *                                callback, if the pointer is not null)
 * @param requested_team_size     forwarded unchanged
 * @param flag                    forwarded unchanged
 * @param codeptr_ra              used to derive the location ID and forwarded
 *                                as address
 */
static void on_ompt_callback_parallel_begin(
    ompt_data_t* encountering_task_data,
    const ompt_frame_t* encountering_task_frame,
    ompt_data_t* parallel_data,
    uint32_t requested_team_size,
    int flag,
    const void* codeptr_ra)
{
    /* The ID is attached first, so the region is tagged even if the event
     * can not be forwarded below. */
    if (parallel_data != nullptr) {
        parallel_data->value = ompt_get_unique_id();
    }

    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    if (adapter == nullptr || !adapter->callback_parallel_begin) {
        return;
    }

    adapter->callback_parallel_begin(
        adapter->getParallelId(),
        adapter->getLocationId(codeptr_ra),
        getOmptData(encountering_task_data),
        getOmptData(parallel_data),
        requested_team_size,
        flag,
        reinterpret_cast<MustAddressType>(codeptr_ra));
}

/**
 * OMPT `parallel_end` callback handler.
 *
 * Passes the event through to MUST as `MUST_ompt_callback_parallel_end` event,
 * provided that an adapter instance exists and the event callback is set.
 *
 *
 * @param parallel_data          unique ID of the parallel region
 * @param encountering_task_data data of the task that encountered the region
 * @param flag                   forwarded unchanged
 * @param codeptr_ra             used to derive the location ID and forwarded
 *                               as address
 */
static void on_ompt_callback_parallel_end(
    ompt_data_t* parallel_data,
    ompt_data_t* encountering_task_data,
    int flag,
    const void* codeptr_ra)
{
    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    if (adapter == nullptr || !adapter->callback_parallel_end) {
        return;
    }

    adapter->callback_parallel_end(
        adapter->getParallelId(),
        adapter->getLocationId(codeptr_ra),
        getOmptData(parallel_data),
        getOmptData(encountering_task_data),
        flag,
        reinterpret_cast<MustAddressType>(codeptr_ra));
}

/**
 * OMPT `task_create` callback handler.
 *
 * Assigns a fresh unique ID to the new task and then passes the event through
 * to MUST as `MUST_ompt_callback_task_create` event, provided that an adapter
 * instance exists and the event callback is set.
 *
 *
 * @param encountering_task_data  data of the task creating the new task
 * @param encountering_task_frame not forwarded to MUST
 * @param new_task_data           unique ID of the task (set by this callback,
 *                                if the pointer is not null)
 * @param type                    forwarded unchanged
 * @param has_dependences         forwarded unchanged
 * @param codeptr_ra              used to derive the location ID and forwarded
 *                                as address
 */
static void on_ompt_callback_task_create(
    ompt_data_t* encountering_task_data,
    const ompt_frame_t* encountering_task_frame,
    ompt_data_t* new_task_data,
    int type,
    int has_dependences,
    const void* codeptr_ra)
{
    /* Guard the data slot in the same way as the `parallel_begin` handler
     * does, instead of dereferencing the pointer unconditionally. */
    if (new_task_data != nullptr)
        new_task_data->value = ompt_get_unique_id();

    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a != nullptr && a->callback_task_create)
        a->callback_task_create(
            a->getParallelId(),
            a->getLocationId(codeptr_ra),
            getOmptData(encountering_task_data),
            getOmptData(new_task_data),
            type,
            has_dependences,
            reinterpret_cast<MustAddressType>(codeptr_ra));
}

/**
 * OMPT `task_schedule` callback handler.
 *
 * Passes the event through to MUST as `MUST_ompt_callback_task_schedule`
 * event.
 *
 * @note If @p prior_task_status is `ompt_task_complete`,
 *       `ompt_task_late_fulfill` or `ompt_task_cancel`, the
 *       `MUST_ompt_callback_task_end` event will be sent to MUST, too.
 *
 *
 * @param first_task_data   data of the task @p prior_task_status refers to
 * @param prior_task_status status of the first task
 * @param second_task_data  data of the second task, forwarded with the
 *                          schedule event only
 */
static void on_ompt_callback_task_schedule(
    ompt_data_t* first_task_data,
    ompt_task_status_t prior_task_status,
    ompt_data_t* second_task_data)
{
    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    if (adapter == nullptr) {
        return;
    }

    if (adapter->callback_task_schedule) {
        adapter->callback_task_schedule(
            adapter->getParallelId(),
            getOmptData(first_task_data),
            prior_task_status,
            getOmptData(second_task_data));
    }

    /* For the statuses listed below an additional task_end event is sent.
     *
     * NOTE: Only first_task_data will be passed, as it is the only relevant
     *       information indicating the task that has ended. */
    switch (prior_task_status) {
    case ompt_task_complete:
    case ompt_task_late_fulfill:
    case ompt_task_cancel:
        if (adapter->callback_task_end) {
            adapter->callback_task_end(adapter->getParallelId(), getOmptData(first_task_data));
        }
        break;
    default:
        break;
    }
}

/**
 * OMPT `dependences` callback handler.
 *
 * Passes the event through to MUST as `MUST_ompt_callback_task_dependences`
 * event, provided that an adapter instance exists and the event callback is
 * set.
 *
 *
 * @param task_data unique ID of the task
 * @param deps      dependence list, forwarded unchanged
 * @param ndeps     forwarded unchanged together with @p deps
 */
static void
on_ompt_callback_dependences(ompt_data_t* task_data, const ompt_dependence_t* deps, int ndeps)
{
    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    if (adapter == nullptr || !adapter->callback_task_dependences) {
        return;
    }

    adapter->callback_task_dependences(
        adapter->getParallelId(),
        getOmptData(task_data),
        deps,
        ndeps);
}

/**
 * OMPT `task_dependence` callback handler.
 *
 * Passes the event through to MUST as
 * `MUST_ompt_callback_task_dependence_pair` event, provided that an adapter
 * instance exists and the event callback is set.
 *
 *
 * @param first_task_data  data of the first task of the pair
 * @param second_task_data data of the second task of the pair
 */
static void
on_ompt_callback_task_dependence(ompt_data_t* first_task_data, ompt_data_t* second_task_data)
{
    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    const bool can_forward = (adapter != nullptr) && adapter->callback_task_dependence_pair;

    if (can_forward) {
        adapter->callback_task_dependence_pair(
            adapter->getParallelId(),
            getOmptData(first_task_data),
            getOmptData(second_task_data));
    }
}

/**
 * OMPT `thread_begin` callback handler.
 *
 * Assigns a fresh unique ID to the thread and then passes the event through to
 * MUST as `MUST_ompt_callback_thread_begin` event, provided that an adapter
 * instance exists and the event callback is set.
 *
 *
 * @param thread_type forwarded unchanged
 * @param thread_data unique ID of the thread (set by this callback, if the
 *                    pointer is not null)
 */
static void on_ompt_callback_thread_begin(ompt_thread_t thread_type, ompt_data_t* thread_data)
{
    /* Guard the data slot in the same way as the `parallel_begin` handler
     * does, instead of dereferencing the pointer unconditionally. */
    if (thread_data != nullptr)
        thread_data->value = ompt_get_unique_id();

    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a != nullptr && a->callback_thread_begin)
        a->callback_thread_begin(a->getParallelId(), thread_type, getOmptData(thread_data));
}

/**
 * OMPT `thread_end` callback handler.
 *
 * This callback handler will passthrough the event to MUST as
 * `MUST_ompt_callback_thread_end` event, provided that an adapter instance
 * exists and the event callback is set.
 *
 *
 * @param thread_data unique ID of the thread
 */
static void on_ompt_callback_thread_end(ompt_data_t* thread_data)
{
    /* NOTE: This callback needs to check, if an instance is returned, as the
     *       MUST stack may be shutdown already, when the thread will be
     *       finished. A single null test is sufficient for that. */
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a != nullptr && a->callback_thread_end)
        a->callback_thread_end(a->getParallelId(), getOmptData(thread_data));
}

/**
 * OMPT `control_tool` callback handler.
 *
 * This callback handler will passthrough the event to MUST as
 * `MUST_ompt_callback_control_tool` event.
 *
 *
 * @param command
 * @param modifier
 * @param arg
 * @param codeptr_ra
 */
static int on_ompt_callback_control_tool(
    uint64_t command,
    uint64_t modifier,
    void* arg,
    const void* codeptr_ra)
{
    OpenMPadapter* a = OpenMPadapter::getInstance("");
    if (a != nullptr && a->callback_control_tool)
        a->callback_control_tool(
            a->getParallelId(),
            a->getLocationId(codeptr_ra),
            command,
            modifier,
            arg,
            reinterpret_cast<MustAddressType>(codeptr_ra));

    return 0;
}

/**
 * OMPT `initialize` callback handler.
 *
 * Hands the arguments of the OMPT initialize function over to the adapter's
 * `callback_initialize`, provided that an adapter instance exists and the
 * callback is set.
 *
 *
 * @param lookup             forwarded unchanged
 * @param initial_device_num forwarded unchanged
 * @param tool_data          forwarded unchanged
 *
 * @return This function always returns `0`.
 */
static int on_ompt_callback_initialize(
    ompt_function_lookup_t lookup,
    int initial_device_num,
    ompt_data_t* tool_data)
{
    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    const bool can_forward = (adapter != nullptr) && adapter->callback_initialize;

    if (can_forward) {
        adapter->callback_initialize(lookup, initial_device_num, tool_data);
    }

    return 0;
}

/**
 * OMPT `finalize` callback handler.
 *
 * Hands @p tool_data over to the adapter's `callback_finalize`, provided that
 * an adapter instance exists and the callback is set.
 *
 *
 * @param tool_data forwarded unchanged
 *
 * @return This function always returns `0`.
 */
static int on_ompt_callback_finalize(ompt_data_t* tool_data)
{
    OpenMPadapter* const adapter = OpenMPadapter::getInstance("");
    const bool can_forward = (adapter != nullptr) && adapter->callback_finalize;

    if (can_forward) {
        adapter->callback_finalize(tool_data);
    }

    return 0;
}

/* The following code is an adapted and commented copy of the LLVM runtime OMPT
 * test case, licensed under the Apache License. */

/**
 * Register a callback function.
 *
 * This macro registers the handler `on_<name>` for the callback @p name. The
 * handler is first bound to a local variable of type @p type, so a handler
 * with a mismatching signature is rejected by the compiler. A message is
 * printed, if the runtime answers the registration with `ompt_set_never`.
 *
 *
 * @param name name of the callback
 * @param type type of the callback function
 */
#define register_callback_t(name, type)                                                            \
    do {                                                                                           \
        type handler_##name = &on_##name;                                                          \
        const auto status_##name =                                                                 \
            ompt_set_callback(name, reinterpret_cast<ompt_callback_t>(handler_##name));            \
        if (status_##name == ompt_set_never) {                                                     \
            must::cout << "Could not register callback '" #name "'\n";                             \
        }                                                                                          \
    } while (0)

/**
 * Register a callback function.
 *
 * Shorthand for `register_callback_t`, which derives the type of the callback
 * function from the callback's name by appending `_t`.
 *
 * @note The callback needs to match a typedef with same name and `_t` postfix.
 *       Callbacks without such a typedef have to be registered with
 *       `register_callback_t` and an explicit type instead.
 *
 *
 * @param name name of the callback
 */
#define register_callback(name) register_callback_t(name, name##_t)

/**
 * OMPT initialize function.
 *
 * This function will be called on OpenMP runtime initialization. It looks up
 * the runtime entry points needed by this adapter, notifies the adapter about
 * the initialization and registers all callbacks provided by this adapter.
 *
 *
 * @param lookup             Lookup OMPT internal functions with this function
 *                           pointer.
 * @param initial_device_num forwarded to the adapter's initialize callback
 * @param tool_data          forwarded to the adapter's initialize callback
 *
 * @return `1`, indicating that this tool is activated, or `0`, if a required
 *         runtime entry point could not be looked up.
 */
static int
ompt_initialize(ompt_function_lookup_t lookup, int initial_device_num, ompt_data_t* tool_data)
{
    /* Lookup necessary functions of the OpenMP runtime to be used by this
     * adapter. These will be available in the adapter's global namespace, as
     * the related variables are global. */
    ompt_set_callback = reinterpret_cast<ompt_set_callback_t>(lookup("ompt_set_callback"));
    ompt_get_unique_id = reinterpret_cast<ompt_get_unique_id_t>(lookup("ompt_get_unique_id"));
    ompt_finalize_tool = reinterpret_cast<ompt_finalize_tool_t>(lookup("ompt_finalize_tool"));

    /* `ompt_set_callback` is called for every registration below and
     * `ompt_get_unique_id` is called by the callback handlers. Calling them
     * through a null pointer would crash the application, so the tool is
     * reported as not activated instead. */
    if (ompt_set_callback == nullptr || ompt_get_unique_id == nullptr) {
        must::cout << "Could not look up the OMPT entry points 'ompt_set_callback' and "
                      "'ompt_get_unique_id'\n";
        return 0;
    }

    on_ompt_callback_initialize(lookup, initial_device_num, tool_data);

    /* Register all callbacks defined in this adapter to be triggered by the
     * OpenMP runtime. */
    register_callback(ompt_callback_mutex_acquire);
    register_callback_t(ompt_callback_mutex_acquired, ompt_callback_mutex_t);
    register_callback_t(ompt_callback_mutex_released, ompt_callback_mutex_t);
    register_callback(ompt_callback_nest_lock);
    register_callback(ompt_callback_sync_region);
    register_callback_t(ompt_callback_sync_region_wait, ompt_callback_sync_region_t);
    register_callback(ompt_callback_control_tool);
    register_callback(ompt_callback_flush);
    register_callback(ompt_callback_cancel);
    register_callback(ompt_callback_implicit_task);
    register_callback_t(ompt_callback_lock_init, ompt_callback_mutex_acquire_t);
    register_callback_t(ompt_callback_lock_destroy, ompt_callback_mutex_t);
    register_callback(ompt_callback_work);
    register_callback(ompt_callback_masked);
    register_callback(ompt_callback_parallel_begin);
    register_callback(ompt_callback_parallel_end);
    register_callback(ompt_callback_task_create);
    register_callback(ompt_callback_task_schedule);
    register_callback(ompt_callback_dependences);
    register_callback(ompt_callback_task_dependence);
    register_callback(ompt_callback_thread_begin);
    register_callback(ompt_callback_thread_end);

    /* Return 1 to indicate that this tool is activated and its callbacks should
     * be called. No further tools will be activated. */
    return 1;
}

/**
 * OMPT finalize function.
 *
 * This function is registered as finalize function in @ref ompt_start_tool
 * and notifies the adapter by invoking the `finalize` callback handler.
 *
 *
 * @param tool_data forwarded to the `finalize` callback handler
 */
static void ompt_finalize(ompt_data_t* tool_data)
{
    /* The handler's status code is dropped, as this function has no return
     * value to pass it on with. */
    static_cast<void>(on_ompt_callback_finalize(tool_data));
}

/**
 * OMPT start tool function.
 *
 * This function is the entry point of the OpenMP tools interface. It hands the
 * @ref ompt_initialize and @ref ompt_finalize functions over to the runtime,
 * unless no instance of OpenMPadapter has been created yet.
 *
 *
 * @param omp_version     not evaluated
 * @param runtime_version not evaluated
 *
 * @return pointers to @ref ompt_initialize and @ref ompt_finalize, or
 *         `nullptr` if no instance of OpenMPadapter has been created so far
 */
extern "C" ompt_start_tool_result_t*
ompt_start_tool(unsigned int omp_version, const char* runtime_version)
{
    // The callbacks must not be the first ones to call OpenMPadapter::getInstance, as this would
    // create an instance outside of GTI's mechanism. (Note: ModuleBase::getInstance is not
    // supposed to be called outside GTI's macros.) Therefore, the tool is only handed to the
    // runtime, if an instance exists already. This condition should not be necessary anymore once
    // MUST makes use of the MPI Sessions model.
    if (OpenMPadapterInitGuard::is_some_instance_created()) {
        static ompt_start_tool_result_t tool_result = {&ompt_initialize, &ompt_finalize, {0}};
        return &tool_result;
    }

    // No instance yet: the OpenMP runtime came up before MPI has been initialized.
    ERROR(
        MUST_OMP,
        "The OpenMP runtime has been initialized before the call to "
        "MPI_Init or MPI_Init_thread. This is currently not supported by MUST. "
        "Checks of threaded MPI usage might not work as expected.")
    return nullptr;
}
