/* This file is part of GTI (Generic Tool Infrastructure)
 *
 * Copyright (C)
 *  2008-2019 ZIH, Technische Universitaet Dresden, Federal Republic of Germany
 *  2008-2019 Lawrence Livermore National Laboratories, United States of America
 *  2013-2024 RWTH Aachen University, Federal Republic of Germany
 *
 * See the LICENSE file in the package base directory for details
 */

/**
 * @file ErrhandlerTracker.hpp
 *
 *  @date 05.12.23
 *  @author Sebastian Grabowski
 */

#ifndef GTI_MODULES_COMM_PROTOCOLS_ERRHANDLERTRACKER_HPP
#define GTI_MODULES_COMM_PROTOCOLS_ERRHANDLERTRACKER_HPP

#include <mpi.h>

#include <unordered_map>
#include <cassert>

#include "safe_ptr.h"

namespace gti
{

/**
 * Keeps track of attached MPI_Errhandler objects and their function pointers.
 *
 * This template class supports the crashhandling code to wrap application-defined errorhandlers
 * with GTI's own errorhandlers. To be able to still execute the app's errorhandlers we need to
 * know the function pointer of the MPI_Errhandler that is attached to an object.
 * The creation and attachment of the errorhandler is a two-step process.
 * The function pointer is only known to us in the first step (add()), whereas the object that it
 * is attached to is only known in the second step (attach()).
 *
 * The methods are thread-safe: each map is guarded by its own sf::safe_ptr lock, and no method
 * holds both locks at the same time.
 *
 * @tparam MPI_OBJECT The type of the MPI object with an attached MPI_Errhandler (e.g. MPI_Comm,
 * MPI_Win, etc.)
 */
template <typename MPI_OBJECT>
class ErrhandlerTracker
{
  public:
    /**
     * The type of the errorhandler's function pointer, i.e. the signature of the user function
     * passed to MPI_<Object>_create_errhandler (object handle pointer, error code pointer,
     * varargs).
     */
    using errhandler_func_t = void(MPI_OBJECT*, int*, ...);

  private:
    /**
     * Helper to store the errorhandler's handle and function pointer so that both can be
     * retrieved by querying a single map.
     */
    struct data {
        MPI_Errhandler errhandler;
        errhandler_func_t* errhandler_func;
    };

    /// Errorhandler handle -> function pointer it was created with (filled by add()).
    sf::safe_ptr<std::unordered_map<MPI_Errhandler, errhandler_func_t*>> myErrhandlerToFuncs{};
    /// MPI object handle -> errorhandler currently attached to it (filled by attach()).
    sf::safe_ptr<std::unordered_map<MPI_OBJECT, data>> myObjToErrhandlers{};

  public:
    /**
     * Registers an errorhandler function pointer and MPI_Errhandler handle pair.
     *
     * If the handle is already registered, its function pointer is replaced. The MPI library may
     * hand out the same handle value again once an earlier errorhandler has been freed, and a
     * stale function pointer must not be kept for it.
     *
     * @param fp the function pointer used when creating the errorhandler object
     * @param h the handle of the created errorhandler
     */
    auto add(errhandler_func_t* fp, MPI_Errhandler h) -> void
    {
        auto xptr = sf::xlock_safe_ptr(myErrhandlerToFuncs);
        auto const result = xptr->emplace(h, fp);
        if (!result.second) {
            // Key already present: emplace did not insert, so overwrite explicitly.
            result.first->second = fp;
        }
    }

    /**
     * Register the attachment of an errorhandler to an MPI object.
     *
     * A later attachment to the same object replaces the earlier one, mirroring the MPI
     * semantics of setting a new errorhandler on an object.
     *
     * The errorhandler is expected to have been registered with add() first (checked by an
     * assertion in debug builds). If it was not, e.g. for a predefined errorhandler, the
     * attachment is still recorded, but with a nullptr function pointer. Release builds then
     * behave well instead of dereferencing an end iterator.
     *
     * @param objHandle the handle of the mpi object
     * @param errhandler the handle of the attached errorhandler
     */
    auto attach(MPI_OBJECT objHandle, MPI_Errhandler errhandler) -> void
    {
        errhandler_func_t* func = nullptr;
        {
            // Only read access is needed here; the scope releases this lock before the second
            // map is locked, so the two locks are never held together.
            auto const sptr = sf::slock_safe_ptr(myErrhandlerToFuncs);
            auto const iter = sptr->find(errhandler);
            assert(iter != sptr->cend() && "errhandler was not registered via add()");
            if (iter != sptr->cend()) {
                func = iter->second;
            }
        }

        data const entry{errhandler, func};
        auto xptr = sf::xlock_safe_ptr(myObjToErrhandlers);
        auto const result = xptr->emplace(objHandle, entry);
        if (!result.second) {
            // The object already had an errorhandler recorded; the new one replaces it.
            result.first->second = entry;
        }
    }

    /**
     * Retrieve the MPI_Errhandler attached to the given MPI object.
     *
     * @param objHandle the object to query by
     * @return the attached errorhandler or MPI_ERRHANDLER_NULL if none has been registered.
     */
    auto getErrhandler(MPI_OBJECT objHandle) const -> MPI_Errhandler
    {
        auto const sptr = sf::slock_safe_ptr(myObjToErrhandlers);
        auto const handler_it = sptr->find(objHandle);
        if (handler_it == sptr->cend()) {
            return MPI_ERRHANDLER_NULL;
        }
        return handler_it->second.errhandler;
    }

    /**
     * Retrieve the function pointer of an attached errorhandler.
     *
     * @param objHandle the object to query by
     * @return the function pointer, or nullptr if no errorhandler has been attached to the
     * object or the attached errorhandler was never registered via add().
     */
    auto getErrhandlerFunc(MPI_OBJECT objHandle) const -> errhandler_func_t*
    {
        auto const sptr = sf::slock_safe_ptr(myObjToErrhandlers);
        auto const handler_it = sptr->find(objHandle);
        if (handler_it == sptr->cend()) {
            return nullptr;
        }
        return handler_it->second.errhandler_func;
    }
};

} // namespace gti

#endif // GTI_MODULES_COMM_PROTOCOLS_ERRHANDLERTRACKER_HPP
