/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file DGroupMatch.cpp
 *       @see MUST::DGroupMatch.
 *
 * @date 14.02.2025
 * @author Cornelius Pätzold
 */

#include "GtiMacros.h"
#include "MustEnums.h"
#include "PrefixedOstream.hpp"

#include "DGroupMatch.h"

#include <assert.h>
#include <sstream>
#include <fstream>

#include <vector>

#define MUST_DEBUG

using namespace must;

// GTI module boilerplate for DGroupMatch; presumably these macros expand to the module's
// instance getter, instance free function and PnMPI registration point (see GtiMacros.h).
mGET_INSTANCE_FUNCTION(DGroupMatch)
mFREE_INSTANCE_FUNCTION(DGroupMatch)
mPNMPI_REGISTRATIONPOINT_FUNCTION(DGroupMatch)

//=============================
// Constructor
//=============================
/**
 * Constructs the distributed group matching module.
 *
 * Acquires the required sub modules (surplus instances are destroyed again), resolves the
 * wrap-across functions used to forward group operations to other places, and verifies that the
 * module is not mapped onto a layer with more than one place unless intra-layer communication
 * (myPassToGroupNoTransferFunction) is available.
 *
 * @param instanceName name of this module instance, forwarded to gti::ModuleBase.
 */
DGroupMatch::DGroupMatch(const char* instanceName)
    : gti::ModuleBase<DGroupMatch, I_DGroupMatch>(instanceName), myPlaceId(-1), myListener(NULL),
      myQs(), myQSize(0), myMaxQSize(0)
{
    // create sub modules
    std::vector<I_Module*> subModInstances;
    subModInstances = createSubModuleInstances();

    // handle sub modules
#define NUM_MODS_REQUIRED 9
    if (subModInstances.size() < NUM_MODS_REQUIRED) {
        must::cerr << "Module has not enough sub modules, check its analysis specification! ("
                   << __FILE__ << "@" << __LINE__ << ")" << std::endl;
        assert(0);
    }
    if (subModInstances.size() > NUM_MODS_REQUIRED) {
        for (std::vector<I_Module*>::size_type i = NUM_MODS_REQUIRED; i < subModInstances.size();
             i++)
            destroySubModuleInstance(subModInstances[i]);
    }

    // Sub modules are expected in this fixed order (presumably as listed in the analysis
    // specification of this module).
    myPIdMod = (I_ParallelIdAnalysis*)subModInstances[0];
    myLIdMod = (I_LocationAnalysis*)subModInstances[1];
    myConsts = (I_BaseConstants*)subModInstances[2];
    myLogger = (I_CreateMessage*)subModInstances[3];
    myWTrack = (I_WinTrack*)subModInstances[4];
    myGTrack = (I_GroupTrack*)subModInstances[5];
    myCTrack = (I_CommTrack*)subModInstances[6];
    myDTrack = (I_DatatypeTrack*)subModInstances[7];
    myFloodControl = (I_FloodControl*)subModInstances[8];
    myProfiler = NULL;

    // Initialize module data: functions used to forward group operations to other places
    getWrapAcrossFunction("passToGroupNoTransfer", (GTI_Fct_t*)&myPassToGroupNoTransferFunction);
    getWrapAcrossFunction("passToGroupN", (GTI_Fct_t*)&myPassToGroupNFunction);
    getWrapAcrossFunction("passToGroupTypes", (GTI_Fct_t*)&myPassToGroupTypesFunction);
    getWrapAcrossFunction("passToGroupCounts", (GTI_Fct_t*)&myPassToGroupCountsFunction);

    // Assert correct mapping: determine whether the application ranks map to more than one place
    // on this level. Counting changes of the place id between consecutive ranks is sufficient for
    // the "> 1" check below, because any change implies at least two distinct places.
    // TODO this is somewhat crude
    int numPlacesOnLevel = 0;
    int tempTarget = -1, lastTarget = -1;
    int rank = 0;

    while (getLevelIdForApplicationRank(rank, &tempTarget) == GTI_SUCCESS) {
        rank++;
        if (lastTarget != tempTarget)
            numPlacesOnLevel++;
        // Must be updated, otherwise every rank would be counted as a separate place.
        lastTarget = tempTarget;
    }

    if (numPlacesOnLevel > 1 && (!myPassToGroupNoTransferFunction)) {
        must::cerr
            << "ERROR: Distributed Group Matching was mapped on a layer of size > 1 while no "
               "intra layer communication was present, as a result Group matching will not be "
               "possible. Either add an intra layer communication, or map the Group matching "
               "onto a layer with a single process."
            << std::endl;
        assert(0);
    }
}

//=============================
// Destructor
//=============================
/**
 * Destroys the module.
 *
 * Releases every sub module instance in a fixed order. The group, communicator and datatype
 * trackers get notifyOfShutdown() before they are released. If a profiler is set, the queue
 * statistics are reported to it. Finally every still-queued handle is erased and every
 * still-queued send/receive operation is deleted.
 */
DGroupMatch::~DGroupMatch()
{
    // Parallel id analysis
    if (myPIdMod != NULL)
        destroySubModuleInstance((I_Module*)myPIdMod);
    myPIdMod = NULL;

    // Location analysis
    if (myLIdMod != NULL)
        destroySubModuleInstance((I_Module*)myLIdMod);
    myLIdMod = NULL;

    // Constants
    if (myConsts != NULL)
        destroySubModuleInstance((I_Module*)myConsts);
    myConsts = NULL;

    // Message logger
    if (myLogger != NULL)
        destroySubModuleInstance((I_Module*)myLogger);
    myLogger = NULL;

    // Window tracker (no shutdown notification)
    if (myWTrack != NULL)
        destroySubModuleInstance((I_Module*)myWTrack);
    myWTrack = NULL;

    // Group tracker
    if (myGTrack != NULL) {
        myGTrack->notifyOfShutdown();
        destroySubModuleInstance((I_Module*)myGTrack);
    }
    myGTrack = NULL;

    // Communicator tracker
    if (myCTrack != NULL) {
        myCTrack->notifyOfShutdown();
        destroySubModuleInstance((I_Module*)myCTrack);
    }
    myCTrack = NULL;

    // Datatype tracker
    if (myDTrack != NULL) {
        myDTrack->notifyOfShutdown();
        destroySubModuleInstance((I_Module*)myDTrack);
    }
    myDTrack = NULL;

    // Flood control
    if (myFloodControl != NULL)
        destroySubModuleInstance((I_Module*)myFloodControl);
    myFloodControl = NULL;

    // Profiler: report queue statistics before releasing it
    if (myProfiler != NULL) {
        myProfiler->reportWrapperAnalysisTime("DGroupMatch", "maxEventQueue", 0, myMaxQSize);
        myProfiler->reportWrapperAnalysisTime("DGroupMatch", "finalQueueSize", 0, myQSize);
        destroySubModuleInstance((I_Module*)myProfiler);
    }
    myProfiler = NULL;

    //== Free the queued handles and operations
    for (auto& rankQueues : myQs) {
        for (auto& handleQueue : rankQueues.second) {
            // Release the handle this queue belongs to
            handleQueue.first->erase();

            // delete on a null pointer is a no-op, so no explicit check is needed
            for (auto* pendingSend : handleQueue.second.sends)
                delete pendingSend;
            for (auto* pendingRecv : handleQueue.second.recvs)
                delete pendingRecv;
        }
    }
    myQs.clear();
}

//=============================
// init
//=============================
/**
 * Lazily determines the place id of this module instance.
 *
 * The place is derived from the rank belonging to @p pId. It is only looked up while
 * myPlaceId is still negative, i.e. on the first call.
 *
 * @param pId parallel id of the calling rank.
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN DGroupMatch::init(MustParallelId pId)
{
    if (myPlaceId >= 0)
        return GTI_ANALYSIS_SUCCESS; // place already known

    const int callerRank = myPIdMod->getInfoForId(pId).rank;
    getLevelIdForApplicationRank(callerRank, &myPlaceId);
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// post
//=============================
/**
 * Handles a window post (MPI_Win_post): the calling rank exposes @p win to the ranks of @p group.
 *
 * Towards every rank of the group the post acts like a buffered send with callId
 * MUST_GROUPMATCH_START_CALLID, which the receive op created in start() matches. The listener is
 * notified of the new operations. The resulting send information is forwarded to each place
 * hosting one of those ranks, or processed directly when that place is this one. The group
 * mapping is remembered in myPostRanks so that the matching wait() knows which ranks to wait for.
 *
 * @param pId parallel id of the calling rank.
 * @param lId location id of the call.
 * @param group group handle passed to the post.
 * @param win window handle passed to the post.
 * @param callId call id (unused).
 * @return always GTI_ANALYSIS_SUCCESS; unknown or null windows/groups are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::post(
    MustParallelId pId,
    MustLocationId lId,
    MustGroupType group,
    MustWinType win,
    int callId)
{
    //== 1) Get win and group infos
    int rank = myPIdMod->getInfoForId(pId).rank;
    // Both pointers must start out as NULL. The lookup helpers do not guarantee to set their
    // out-parameter on failure, and getGroupInfo is not called at all if getWinInfo fails
    // (short-circuit). Without the initialization, the cleanup below would erase() garbage.
    I_WinPersistent* wInfo = NULL;
    I_GroupPersistent* gInfo = NULL;
    if (!getWinInfo(pId, win, &wInfo) || !getGroupInfo(pId, group, &gInfo)) {
        if (wInfo)
            wInfo->erase();
        if (gInfo)
            gInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    //== 2) Notify the listener and group the send information by target place
    I_CommPersistent* cInfo = wInfo->getComm();
    std::map<int, PassInfo> targetPlaces = getPlacesAndNotifyListener(
        pId,
        lId,
        cInfo,
        wInfo,
        true,
        MUST_BUFFERED_SEND,
        gInfo->getGroup()->getMapping().data(),
        gInfo->getGroup()->getMapping().size());

    //== 3) Forward (remote places) or process (this place) the send information
    for (auto& targetPlace : targetPlaces) {
        if (myPlaceId != targetPlace.first) {
            // Forward associated resources first
            myLIdMod->passLocationToPlace(pId, lId, targetPlace.first);
            MustRemoteIdType winRId;
            myWTrack->passWinAcross(rank, wInfo, targetPlace.first, &winRId);
            // Forward send information to targetPlace
            if (myPassToGroupNoTransferFunction)
                (*myPassToGroupNoTransferFunction)(
                    pId,
                    lId,
                    targetPlace.second.ranks.data(),
                    targetPlace.second.timestamps.data(),
                    targetPlace.second.ranks.size(),
                    MUST_GROUPMATCH_WIN_HANDLE,
                    winRId,
                    MUST_GROUPMATCH_START_CALLID,
                    false,
                    targetPlace.first);
        } else {
            recvPassToGroupNoTransfer(
                pId,
                lId,
                targetPlace.second.ranks.data(),
                targetPlace.second.timestamps.data(),
                targetPlace.second.ranks.size(),
                MUST_GROUPMATCH_WIN_HANDLE,
                win,
                MUST_GROUPMATCH_START_CALLID,
                false);
        }
    }

    // Remember the exposure group for the matching wait()
    myPostRanks[win][rank] = gInfo->getGroup()->getMapping();
    gInfo->erase();
    wInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

// =============================
// start
//=============================
/**
 * Handles a window start (MPI_Win_start): the calling rank starts an access epoch on @p win
 * towards the ranks of @p group.
 *
 * The ranks of the group are treated as senders: the listener is notified and a receive-side
 * DGroupOp with callId MUST_GROUPMATCH_START_CALLID is processed. That op matches the send
 * information produced by post(). A completion operation, similar to an MPI_Waitall over the
 * created operations, is then registered with the listener. The group mapping is remembered in
 * myStartRanks so that complete() knows whom to notify.
 *
 * @param pId parallel id of the calling rank.
 * @param lId location id of the call.
 * @param group group handle passed to the start.
 * @param win window handle passed to the start.
 * @param callId call id (unused).
 * @return always GTI_ANALYSIS_SUCCESS; unknown or null windows/groups are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::start(
    MustParallelId pId,
    MustLocationId lId,
    MustGroupType group,
    MustWinType win,
    int callId)
{
    // Get infos and translation
    int rank = myPIdMod->getInfoForId(pId).rank;
    // Both pointers must start out as NULL. The lookup helpers do not guarantee to set their
    // out-parameter on failure, and getGroupInfo is not called at all if getWinInfo fails
    // (short-circuit). Without the initialization, the cleanup below would erase() garbage.
    I_WinPersistent* wInfo = NULL;
    I_GroupPersistent* gInfo = NULL;
    if (!getWinInfo(pId, win, &wInfo) || !getGroupInfo(pId, group, &gInfo)) {
        if (wInfo)
            wInfo->erase();
        if (gInfo)
            gInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    I_CommPersistent* cInfo = wInfo->getComm();
    std::vector<MustLTimeStamp> timestamps;
    notifyListenerOfNewOp(
        pId,
        lId,
        cInfo,
        wInfo,
        false,
        MUST_UNKNOWN_SEND,
        gInfo->getGroup()->getMapping(),
        timestamps);

    // Create recv op for MPI_Win_start
    DGroupOp* recvOp = new DGroupOp(
        pId,
        lId,
        wInfo,
        this,
        false,
        gInfo->getGroup()->getMapping(),
        timestamps,
        MUST_GROUPMATCH_START_CALLID,
        false);

    //== 2) Process or Queue ?
    recvOp->process(rank);

    // create an operation that is semantically similar to an MPI_Waitall that waits for the ops
    // associated with the "timestamps". The extra reference on cInfo is handed to that operation
    // (presumably released by the listener).
    cInfo->copy();
    myListener
        ->newGroupCompletionOp(pId, lId, cInfo, timestamps.data(), timestamps.size(), false, 0);

    // Remember the access group for the matching complete()
    myStartRanks[win][rank] = gInfo->getGroup()->getMapping();
    gInfo->erase();
    wInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// complete
//=============================
/**
 * Handles a window complete (MPI_Win_complete): ends the access epoch started by start().
 *
 * Towards every rank recorded by start() for this window/rank, the complete acts like a buffered
 * send with callId MUST_GROUPMATCH_WAIT_CALLID, which the receive op created in wait() matches.
 * The send information is forwarded to the hosting places, or processed directly on this place.
 * Calls without a recorded start are ignored.
 *
 * @param pId parallel id of the calling rank.
 * @param lId location id of the call.
 * @param win window handle passed to the complete.
 * @param callId call id (unused).
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN
DGroupMatch::complete(MustParallelId pId, MustLocationId lId, MustWinType win, int callId)
{
    //== 1) Get win info
    int rank = myPIdMod->getInfoForId(pId).rank;
    I_WinPersistent* wInfo = NULL;
    if (!getWinInfo(pId, win, &wInfo)) {
        return GTI_ANALYSIS_SUCCESS;
    }

    //== 2) Find the group this rank passed to its start() on this window
    // Iterator lookups instead of repeated myStartRanks[win] accesses.
    auto winStarts = myStartRanks.find(win);
    if (winStarts == myStartRanks.end()) {
        wInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    auto toRanks = winStarts->second.find(rank);
    if (toRanks == winStarts->second.end()) {
        wInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    //== 3) Notify the listener and group the send information by target place
    I_CommPersistent* cInfo = wInfo->getComm();
    std::map<int, PassInfo> targetPlaces = getPlacesAndNotifyListener(
        pId,
        lId,
        cInfo,
        wInfo,
        true,
        MUST_BUFFERED_SEND,
        toRanks->second.data(),
        toRanks->second.size());

    //== 4) Forward (remote places) or process (this place) the send information
    for (auto& targetPlace : targetPlaces) {
        if (myPlaceId != targetPlace.first) {
            // Forward associated resources first
            myLIdMod->passLocationToPlace(pId, lId, targetPlace.first);
            MustRemoteIdType winRId;
            myWTrack->passWinAcross(rank, wInfo, targetPlace.first, &winRId);
            // Forward send information to targetPlace
            if (myPassToGroupNoTransferFunction)
                (*myPassToGroupNoTransferFunction)(
                    pId,
                    lId,
                    targetPlace.second.ranks.data(),
                    targetPlace.second.timestamps.data(),
                    targetPlace.second.ranks.size(),
                    MUST_GROUPMATCH_WIN_HANDLE,
                    winRId,
                    MUST_GROUPMATCH_WAIT_CALLID,
                    false,
                    targetPlace.first);
        } else {
            recvPassToGroupNoTransfer(
                pId,
                lId,
                targetPlace.second.ranks.data(),
                targetPlace.second.timestamps.data(),
                targetPlace.second.ranks.size(),
                MUST_GROUPMATCH_WIN_HANDLE,
                win,
                MUST_GROUPMATCH_WAIT_CALLID,
                false);
        }
    }

    // The access epoch is over; forget the recorded group
    myStartRanks[win].erase(rank);
    wInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

// =============================
// wait
//=============================
/**
 * Handles a window wait (MPI_Win_wait): ends the exposure epoch opened by post().
 *
 * The ranks recorded by post() for this window/rank are treated as senders: the listener is
 * notified and a receive-side DGroupOp with callId MUST_GROUPMATCH_WAIT_CALLID is processed.
 * That op matches the send information produced by complete(). A completion operation, similar
 * to an MPI_Waitall over the created operations, is then registered. Calls without a recorded
 * post are ignored.
 *
 * @param pId parallel id of the calling rank.
 * @param lId location id of the call.
 * @param win window handle passed to the wait.
 * @param callId call id (unused).
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN
DGroupMatch::wait(MustParallelId pId, MustLocationId lId, MustWinType win, int callId)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;

    I_WinPersistent* wInfo;
    if (!getWinInfo(pId, win, &wInfo))
        return GTI_ANALYSIS_SUCCESS;

    // Look up the group this rank passed to post() on this window
    auto winPosts = myPostRanks.find(win);
    if (winPosts == myPostRanks.end()) {
        wInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }
    auto postedGroup = winPosts->second.find(rank);
    if (postedGroup == winPosts->second.end()) {
        wInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // The listener creates one operation per potential sender and reports its timestamp
    I_CommPersistent* winComm = wInfo->getComm();
    std::vector<MustLTimeStamp> opStamps;
    notifyListenerOfNewOp(
        pId,
        lId,
        winComm,
        wInfo,
        false,
        MUST_UNKNOWN_SEND,
        postedGroup->second,
        opStamps);

    // Receive op for the wait; once it completed, the matching sends are passed along
    DGroupOp* waitOp = new DGroupOp(
        pId,
        lId,
        wInfo,
        this,
        false,
        postedGroup->second,
        opStamps,
        MUST_GROUPMATCH_WAIT_CALLID,
        false);
    waitOp->process(rank);

    // Completion op that behaves like an MPI_Waitall over the ops behind "opStamps"; it gets its
    // own reference on the communicator
    winComm->copy();
    myListener->newGroupCompletionOp(
        pId,
        lId,
        winComm,
        opStamps.data(),
        opStamps.size(),
        false,
        0);

    // The exposure epoch is over; forget the recorded group
    myPostRanks[win].erase(rank);
    wInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// nbrSendN
//=============================
/**
 * Send side of a neighborhood collective in which every out-neighbor receives @p sendcount
 * elements of @p sendtype.
 *
 * The listener is notified of one buffered send per out-neighbor of @p comm. The resulting send
 * information is grouped by target place. For a remote place, it is forwarded together with the
 * location, communicator and datatype it refers to. For this place, it is handled directly.
 *
 * @return always GTI_ANALYSIS_SUCCESS; unknown communicators/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::nbrSendN(
    MustParallelId pId,
    MustLocationId lId,
    MustCommType comm,
    int callId,
    int sendcount,
    MustDatatypeType sendtype)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;
    int sourcePlace;
    getLevelIdForApplicationRank(rank, &sourcePlace);

    I_CommPersistent* cInfo;
    if (!getCommInfo(pId, comm, &cInfo))
        return GTI_ANALYSIS_SUCCESS;

    I_DatatypePersistent* dInfo;
    if (!getTypeInfo(pId, sendtype, &dInfo)) {
        cInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // One buffered send per out-neighbor, grouped by the place hosting the receiving rank
    std::map<int, PassInfo> targetPlaces = getPlacesAndNotifyListener(
        pId,
        lId,
        cInfo,
        NULL,
        true,
        MUST_BUFFERED_SEND,
        cInfo->getOutNeighbors().data(),
        cInfo->getOutNeighbors().size());

    for (auto& entry : targetPlaces) {
        const int place = entry.first;
        auto& info = entry.second;

        if (place == myPlaceId) {
            // Receivers are handled by this place: no forwarding needed
            recvPassToGroupN(
                pId,
                lId,
                sendtype,
                sendcount,
                info.ranks.data(),
                info.timestamps.data(),
                info.ranks.size(),
                MUST_GROUPMATCH_COMM_HANDLE,
                comm,
                callId,
                true);
            continue;
        }

        // The remote place must know the referenced resources before the send information
        myLIdMod->passLocationToPlace(pId, lId, place);
        MustRemoteIdType commRemoteId;
        myCTrack->passCommAcross(rank, cInfo, place, &commRemoteId);
        MustRemoteIdType typeRemoteId;
        myDTrack->passDatatypeAcross(rank, dInfo, place, &typeRemoteId);

        if (myPassToGroupNFunction)
            (*myPassToGroupNFunction)(
                pId,
                lId,
                typeRemoteId,
                sendcount,
                info.ranks.data(),
                info.timestamps.data(),
                info.ranks.size(),
                MUST_GROUPMATCH_COMM_HANDLE,
                commRemoteId,
                callId,
                true,
                place);
    }

    dInfo->erase();
    cInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// nbrSendCounts
//=============================
/**
 * Send side of a neighborhood collective with one send count per out-neighbor and a single
 * datatype @p sendtype.
 *
 * The listener is notified of one buffered send per out-neighbor of @p comm. @p sendcounts is
 * presumably indexed like the out-neighbor list. The resulting send information is grouped by
 * target place. For a remote place, it is forwarded together with the location, communicator and
 * datatype it refers to. For this place, it is handled directly.
 *
 * @return always GTI_ANALYSIS_SUCCESS; unknown communicators/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::nbrSendCounts(
    MustParallelId pId,
    MustLocationId lId,
    MustCommType comm,
    int callId,
    const int* sendcounts,
    MustDatatypeType sendtype)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;
    int sourcePlace;
    getLevelIdForApplicationRank(rank, &sourcePlace);

    I_CommPersistent* cInfo;
    if (!getCommInfo(pId, comm, &cInfo))
        return GTI_ANALYSIS_SUCCESS;

    I_DatatypePersistent* dInfo;
    if (!getTypeInfo(pId, sendtype, &dInfo)) {
        cInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // One buffered send per out-neighbor, grouped by the place hosting the receiving rank
    std::map<int, PassInfo> targetPlaces = getPlacesAndNotifyListener(
        pId,
        lId,
        cInfo,
        NULL,
        true,
        MUST_BUFFERED_SEND,
        cInfo->getOutNeighbors().data(),
        sendcounts,
        cInfo->getOutNeighbors().size());

    for (auto& entry : targetPlaces) {
        const int place = entry.first;
        auto& info = entry.second;

        if (place == myPlaceId) {
            // Receivers are handled by this place: no forwarding needed
            recvPassToGroupCounts(
                pId,
                lId,
                sendtype,
                info.sendcounts.data(),
                info.ranks.data(),
                info.timestamps.data(),
                info.ranks.size(),
                MUST_GROUPMATCH_COMM_HANDLE,
                comm,
                callId,
                true);
            continue;
        }

        // The remote place must know the referenced resources before the send information
        myLIdMod->passLocationToPlace(pId, lId, place);
        MustRemoteIdType commRemoteId;
        myCTrack->passCommAcross(rank, cInfo, place, &commRemoteId);
        MustRemoteIdType typeRemoteId;
        myDTrack->passDatatypeAcross(rank, dInfo, place, &typeRemoteId);

        if (myPassToGroupCountsFunction)
            (*myPassToGroupCountsFunction)(
                pId,
                lId,
                typeRemoteId,
                info.sendcounts.data(),
                info.ranks.data(),
                info.timestamps.data(),
                info.ranks.size(),
                MUST_GROUPMATCH_COMM_HANDLE,
                commRemoteId,
                callId,
                true,
                place);
    }

    dInfo->erase();
    cInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// nbrSendTypes
//=============================
/**
 * Send side of a neighborhood collective with one send count and one datatype per out-neighbor.
 *
 * The listener is notified of one buffered send per out-neighbor of @p comm. The resulting send
 * information is grouped by target place. For a remote place, the location, the communicator
 * and each datatype are passed across first, and the datatypes in the pass information are
 * replaced by their remote ids. The information is then forwarded. For this place, it is handled
 * directly.
 *
 * @return always GTI_ANALYSIS_SUCCESS; unknown communicators/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::nbrSendTypes(
    MustParallelId pId,
    MustLocationId lId,
    MustCommType comm,
    int callId,
    const int* sendcounts,
    const MustDatatypeType* sendtypes)
{
    // 1) Get comm info
    int rank = myPIdMod->getInfoForId(pId).rank;
    I_CommPersistent* cInfo;
    if (!getCommInfo(pId, comm, &cInfo)) {
        return GTI_ANALYSIS_SUCCESS;
    }

    // 2) Send part
    // 2.1) Get the places that the send information should be sent to and notify the listener
    std::map<int, PassInfo> targetPlaces = getPlacesAndNotifyListener(
        pId,
        lId,
        cInfo,
        NULL,
        true,
        MUST_BUFFERED_SEND,
        cInfo->getOutNeighbors().data(),
        sendcounts,
        sendtypes,
        cInfo->getOutNeighbors().size());

    // 2.2) Pass send information to the places
    for (auto& targetPlace : targetPlaces) {
        if (myPlaceId != targetPlace.first) {
            // Forward associated resources first
            myLIdMod->passLocationToPlace(pId, lId, targetPlace.first);
            MustRemoteIdType remoteId;
            myCTrack->passCommAcross(rank, cInfo, targetPlace.first, &remoteId);

            // Pass every datatype across; its entry in "types" is overwritten with the remote id
            for (auto& type : targetPlace.second.types) {
                I_DatatypePersistent* dInfo;
                if (!getTypeInfo(pId, type, &dInfo)) {
                    cInfo->erase();
                    return GTI_ANALYSIS_SUCCESS;
                }
                myDTrack->passDatatypeAcross(rank, dInfo, targetPlace.first, &type);
                // Release the reference obtained from getTypeInfo. Passing a datatype across
                // does not take it over (nbrSendN also erases its dInfo after passing it).
                dInfo->erase();
            }
            // Forward send information to targetPlace
            if (myPassToGroupTypesFunction) {
                (*myPassToGroupTypesFunction)(
                    pId,
                    lId,
                    targetPlace.second.types.data(),
                    targetPlace.second.sendcounts.data(),
                    targetPlace.second.ranks.data(),
                    targetPlace.second.timestamps.data(),
                    targetPlace.second.ranks.size(),
                    MUST_GROUPMATCH_COMM_HANDLE,
                    remoteId,
                    callId,
                    true,
                    targetPlace.first);
            }
        } else {
            recvPassToGroupTypes(
                pId,
                lId,
                targetPlace.second.types.data(),
                targetPlace.second.sendcounts.data(),
                targetPlace.second.ranks.data(),
                targetPlace.second.timestamps.data(),
                targetPlace.second.ranks.size(),
                MUST_GROUPMATCH_COMM_HANDLE,
                comm,
                callId,
                true);
        }
    }
    cInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// nbrRecvN
//=============================
/**
 * Receive side of a neighborhood collective in which @p recvcount elements of @p recvtype are
 * received from every in-neighbor.
 *
 * Every in-neighbor of @p comm is treated as a potential sender. The listener is notified, and a
 * receive-side DGroupOp is created and processed. A completion operation similar to an
 * MPI_Waitall over the created operations is then registered. If @p hasRequest is set, that
 * operation is non-blocking and is later completed through @p request.
 *
 * @return always GTI_ANALYSIS_SUCCESS; unknown communicators/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::nbrRecvN(
    MustParallelId pId,
    MustLocationId lId,
    MustCommType comm,
    int callId,
    int recvcount,
    MustDatatypeType recvtype,
    bool hasRequest,
    MustRequestType request)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;
    int sourcePlace;
    getLevelIdForApplicationRank(rank, &sourcePlace);

    I_CommPersistent* cInfo;
    if (!getCommInfo(pId, comm, &cInfo))
        return GTI_ANALYSIS_SUCCESS;

    I_DatatypePersistent* dInfo;
    if (!getTypeInfo(pId, recvtype, &dInfo)) {
        cInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // One timestamp per operation the listener creates for the in-neighbors
    std::vector<MustLTimeStamp> opStamps;
    notifyListenerOfNewOp(
        pId,
        lId,
        cInfo,
        NULL,
        false,
        MUST_UNKNOWN_SEND,
        cInfo->getInNeighbors(),
        opStamps);

    // Receive-side group operation, processed right away
    DGroupOp* receiveOp = new DGroupOp(
        pId,
        lId,
        cInfo,
        this,
        false,
        cInfo->getInNeighbors(),
        opStamps,
        recvtype,
        dInfo,
        recvcount,
        callId,
        true);
    receiveOp->process(rank);

    // Waitall-like completion op (non-blocking if hasRequest); it gets its own communicator
    // reference
    cInfo->copy();
    myListener->newGroupCompletionOp(
        pId,
        lId,
        cInfo,
        opStamps.data(),
        opStamps.size(),
        hasRequest,
        request);

    cInfo->erase();
    dInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// nbrRecvCounts
//=============================
/**
 * Receive side of a neighborhood collective with one receive count per in-neighbor and a single
 * datatype @p recvtype.
 *
 * Every in-neighbor of @p comm is treated as a potential sender. The listener is notified, and a
 * receive-side DGroupOp is created from the first getInNeighborsCount() entries of
 * @p recvcounts and processed. A completion operation similar to an MPI_Waitall over the created
 * operations is then registered. If @p hasRequest is set, that operation is non-blocking and is
 * later completed through @p request.
 *
 * @return always GTI_ANALYSIS_SUCCESS; unknown communicators/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::nbrRecvCounts(
    MustParallelId pId,
    MustLocationId lId,
    MustCommType comm,
    int callId,
    const int* recvcounts,
    MustDatatypeType recvtype,
    bool hasRequest,
    MustRequestType request)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;
    int sourcePlace;
    getLevelIdForApplicationRank(rank, &sourcePlace);

    I_CommPersistent* cInfo;
    if (!getCommInfo(pId, comm, &cInfo))
        return GTI_ANALYSIS_SUCCESS;

    I_DatatypePersistent* dInfo;
    if (!getTypeInfo(pId, recvtype, &dInfo)) {
        cInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // One timestamp per operation the listener creates for the in-neighbors
    std::vector<MustLTimeStamp> opStamps;
    notifyListenerOfNewOp(
        pId,
        lId,
        cInfo,
        NULL,
        false,
        MUST_UNKNOWN_SEND,
        cInfo->getInNeighbors(),
        opStamps);

    // Receive-side group operation with one count per in-neighbor, processed right away
    DGroupOp* receiveOp = new DGroupOp(
        pId,
        lId,
        cInfo,
        this,
        false,
        cInfo->getInNeighbors(),
        opStamps,
        recvtype,
        dInfo,
        std::vector<int>(recvcounts, recvcounts + cInfo->getInNeighborsCount()),
        callId,
        true);
    receiveOp->process(rank);

    // Waitall-like completion op (non-blocking if hasRequest); it gets its own communicator
    // reference
    cInfo->copy();
    myListener->newGroupCompletionOp(
        pId,
        lId,
        cInfo,
        opStamps.data(),
        opStamps.size(),
        hasRequest,
        request);

    cInfo->erase();
    dInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// nbrRecvTypes
//=============================
/**
 * Receive side of a neighborhood collective with one receive count and one datatype per
 * in-neighbor.
 *
 * Every in-neighbor of @p comm is treated as a potential sender. The listener is notified, and a
 * receive-side DGroupOp is created and processed. A completion operation similar to an
 * MPI_Waitall over the created operations is then registered. If @p hasRequest is set, that
 * operation is non-blocking and is later completed through @p request.
 *
 * Reference handling follows nbrRecvN and recvPassToGroupTypes: the infos acquired here are
 * released once after the DGroupOp has been created. Those functions erase their infos right
 * after constructing the DGroupOp, so DGroupOp presumably takes its own references.
 *
 * @return always GTI_ANALYSIS_SUCCESS; unknown communicators/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::nbrRecvTypes(
    MustParallelId pId,
    MustLocationId lId,
    MustCommType comm,
    int callId,
    const int* recvcounts,
    const MustDatatypeType* recvtypes,
    bool hasRequest,
    MustRequestType request)
{
    // 1) Get comm info
    int rank = myPIdMod->getInfoForId(pId).rank;
    I_CommPersistent* cInfo;
    if (!getCommInfo(pId, comm, &cInfo)) {
        return GTI_ANALYSIS_SUCCESS;
    }

    // Resolve one datatype per in-neighbor. No extra copy() is taken: each info is erased exactly
    // once below, matching the single reference obtained from getTypeInfo.
    std::vector<I_DatatypePersistent*> typeInfos;
    for (int i = 0; i < cInfo->getInNeighborsCount(); i++) {
        I_DatatypePersistent* dInfo;
        if (!getTypeInfo(pId, recvtypes[i], &dInfo)) {
            // Release everything acquired so far before bailing out
            for (auto& acquired : typeInfos)
                acquired->erase();
            cInfo->erase();
            return GTI_ANALYSIS_SUCCESS;
        }
        typeInfos.emplace_back(dInfo);
    }

    // 2) Recv part
    // 2.1) Notify the listener and collect the timestamps for each created DWaitState operation
    std::vector<MustLTimeStamp> timestamps;
    notifyListenerOfNewOp(
        pId,
        lId,
        cInfo,
        NULL,
        false,
        MUST_UNKNOWN_SEND,
        cInfo->getInNeighbors(),
        timestamps);

    // 2.2) Create groupMatch operation
    DGroupOp* recvOp = new DGroupOp(
        pId,
        lId,
        cInfo,
        this,
        false,
        cInfo->getInNeighbors(),
        timestamps,
        std::vector<MustDatatypeType>(recvtypes, recvtypes + cInfo->getInNeighborsCount()),
        typeInfos,
        std::vector<int>(recvcounts, recvcounts + cInfo->getInNeighborsCount()),
        callId,
        true);

    // 2.3 process the groupMatch operation
    recvOp->process(rank);

    // create an operation that is semantically similar to an MPI_Waitall that waits for the ops
    // associated with the "timestamps". If this has a request then we create a "non-blocking
    // waitall" that later is handled by an actual MPI_Wait etc. The extra reference on cInfo is
    // for that operation.
    cInfo->copy();
    myListener->newGroupCompletionOp(
        pId,
        lId,
        cInfo,
        timestamps.data(),
        timestamps.size(),
        hasRequest,
        request);

    cInfo->erase();
    for (auto& dInfo : typeInfos)
        dInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// recvPassToGroupNoTransfer
//=============================
/**
 * Receives send information for a group operation without data transfer (window
 * synchronization) and processes it as a send-side DGroupOp on this place.
 *
 * Called either through the wrap-across function from another place, or directly when the
 * information originates on this place.
 *
 * @param worldRanks target ranks, @p worldRanksCount entries.
 * @param timestamps listener timestamps belonging to @p worldRanks, @p worldRanksCount entries.
 * @param handleType MUST_GROUPMATCH_HANDLE kind of @p remoteHandleId.
 * @param remoteHandleId handle id that is resolved through getPassedHandleInfo.
 * @return always GTI_ANALYSIS_SUCCESS; unknown handles are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::recvPassToGroupNoTransfer(
    MustParallelId pId,
    MustLocationId lId,
    const int* worldRanks,
    const MustLTimeStamp* timestamps,
    int worldRanksCount,
    int handleType,
    MustRemoteIdType remoteHandleId,
    int callId,
    bool isCollective)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;

    I_Destructable* handleInfo;
    const bool handleKnown = getPassedHandleInfo(
        pId,
        remoteHandleId,
        (MUST_GROUPMATCH_HANDLE)handleType,
        &handleInfo);
    if (!handleKnown)
        return GTI_ANALYSIS_SUCCESS;

    std::vector<int> targetRanks(worldRanks, worldRanks + worldRanksCount);
    std::vector<MustLTimeStamp> targetStamps(timestamps, timestamps + worldRanksCount);

    // Send side of the operation (the receive side passes false here)
    DGroupOp* sendOp = new DGroupOp(
        pId,
        lId,
        handleInfo,
        this,
        true,
        targetRanks,
        targetStamps,
        callId,
        isCollective);
    sendOp->process(rank);

    handleInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Receives send information for a group operation in which every target rank receives
 * @p sendCount elements of datatype @p type, and processes it as a send-side DGroupOp.
 *
 * Called either through the wrap-across function from another place, or directly when the
 * information originates on this place.
 *
 * @param worldRanks target ranks, @p count entries.
 * @param timestamps listener timestamps belonging to @p worldRanks, @p count entries.
 * @param handleType MUST_GROUPMATCH_HANDLE kind of @p remoteHandleId.
 * @return always GTI_ANALYSIS_SUCCESS; unknown handles/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::recvPassToGroupN(
    MustParallelId pId,
    MustLocationId lId,
    MustDatatypeType type,
    int sendCount,
    const int* worldRanks,
    const MustLTimeStamp* timestamps,
    int count,
    int handleType,
    MustRemoteIdType remoteHandleId,
    int callId,
    bool isCollective)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;

    I_Destructable* handleInfo;
    if (!getPassedHandleInfo(
            pId,
            remoteHandleId,
            (MUST_GROUPMATCH_HANDLE)handleType,
            &handleInfo))
        return GTI_ANALYSIS_SUCCESS;

    I_DatatypePersistent* typeInfo;
    const bool typeKnown = getTypeInfo(pId, type, &typeInfo);
    if (!typeKnown) {
        handleInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // Send side of the operation; same datatype and count for every target rank
    DGroupOp* sendOp = new DGroupOp(
        pId,
        lId,
        handleInfo,
        this,
        true,
        std::vector<int>(worldRanks, worldRanks + count),
        std::vector<MustLTimeStamp>(timestamps, timestamps + count),
        type,
        typeInfo,
        sendCount,
        callId,
        isCollective);
    sendOp->process(rank);

    handleInfo->erase();
    typeInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Receives send information for a group operation with one send count per target rank and a
 * single datatype @p type, and processes it as a send-side DGroupOp.
 *
 * Called either through the wrap-across function from another place, or directly when the
 * information originates on this place.
 *
 * @param sendCounts per-target counts, @p count entries.
 * @param worldRanks target ranks, @p count entries.
 * @param timestamps listener timestamps belonging to @p worldRanks, @p count entries.
 * @param handleType MUST_GROUPMATCH_HANDLE kind of @p remoteHandleId.
 * @return always GTI_ANALYSIS_SUCCESS; unknown handles/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::recvPassToGroupCounts(
    MustParallelId pId,
    MustLocationId lId,
    MustDatatypeType type,
    const int* sendCounts,
    const int* worldRanks,
    const MustLTimeStamp* timestamps,
    int count,
    int handleType,
    MustRemoteIdType remoteHandleId,
    int callId,
    bool isCollective)
{
    const int rank = myPIdMod->getInfoForId(pId).rank;

    I_Destructable* handleInfo;
    if (!getPassedHandleInfo(
            pId,
            remoteHandleId,
            (MUST_GROUPMATCH_HANDLE)handleType,
            &handleInfo))
        return GTI_ANALYSIS_SUCCESS;

    I_DatatypePersistent* typeInfo;
    const bool typeKnown = getTypeInfo(pId, type, &typeInfo);
    if (!typeKnown) {
        handleInfo->erase();
        return GTI_ANALYSIS_SUCCESS;
    }

    // Send side of the operation; one count per target rank
    DGroupOp* sendOp = new DGroupOp(
        pId,
        lId,
        handleInfo,
        this,
        true,
        std::vector<int>(worldRanks, worldRanks + count),
        std::vector<MustLTimeStamp>(timestamps, timestamps + count),
        type,
        typeInfo,
        std::vector<int>(sendCounts, sendCounts + count),
        callId,
        isCollective);
    sendOp->process(rank);

    handleInfo->erase();
    typeInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Receives send information for a group operation with one send count and one datatype per
 * target rank, and processes it as a send-side DGroupOp.
 *
 * Called either through the wrap-across function from another place, or directly when the
 * information originates on this place.
 *
 * @param types per-target datatype handles, @p count entries.
 * @param sendCounts per-target counts, @p count entries.
 * @param worldRanks target ranks, @p count entries.
 * @param timestamps listener timestamps belonging to @p worldRanks, @p count entries.
 * @param handleType MUST_GROUPMATCH_HANDLE kind of @p remoteHandleId.
 * @return always GTI_ANALYSIS_SUCCESS; unknown handles/datatypes are silently ignored.
 */
GTI_ANALYSIS_RETURN DGroupMatch::recvPassToGroupTypes(
    MustParallelId pId,
    MustLocationId lId,
    MustDatatypeType* types,
    const int* sendCounts,
    const int* worldRanks,
    const MustLTimeStamp* timestamps,
    int count,
    int handleType,
    MustRemoteIdType remoteHandleId,
    int callId,
    bool isCollective)
{
    int rank = myPIdMod->getInfoForId(pId).rank;
    I_Destructable* handleInfo;
    if (!getPassedHandleInfo(
            pId,
            remoteHandleId,
            (MUST_GROUPMATCH_HANDLE)handleType,
            &handleInfo)) {
        return GTI_ANALYSIS_SUCCESS;
    }

    std::vector<I_DatatypePersistent*> typeInfos;
    for (int i = 0; i < count; i++) {
        I_DatatypePersistent* dInfo;
        if (!getTypeInfo(pId, types[i], &dInfo)) {
            // Release the datatype infos acquired so far, they would leak otherwise
            for (auto& acquired : typeInfos)
                acquired->erase();
            handleInfo->erase();
            return GTI_ANALYSIS_SUCCESS;
        }
        typeInfos.emplace_back(dInfo);
    }

    // Send side of the operation (the receive side passes false here)
    DGroupOp* newOp = new DGroupOp(
        pId,
        lId,
        handleInfo,
        this,
        true,
        std::vector<int>(worldRanks, worldRanks + count),
        std::vector<MustLTimeStamp>(timestamps, timestamps + count),
        std::vector<MustDatatypeType>(types, types + count),
        typeInfos,
        std::vector<int>(sendCounts, sendCounts + count),
        callId,
        isCollective);

    newOp->process(rank);

    handleInfo->erase();
    for (auto& dInfo : typeInfos)
        dInfo->erase();
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// getWinInfo
//=============================
/**
 * Looks up the persistent info for window handle @p win issued by @p pId.
 *
 * The local tracker table is used when the issuing rank belongs to this
 * place, the remote table otherwise.
 *
 * @param pOutWin receives the info on success (may be NULL; then the obtained
 *        reference is not released here).
 * @return false for an unknown window or MPI_WIN_NULL (the latter's info is
 *         released via erase()), true otherwise.
 */
bool DGroupMatch::getWinInfo(MustParallelId pId, MustWinType win, I_WinPersistent** pOutWin)
{
    const int issuerRank = myPIdMod->getInfoForId(pId).rank;
    int issuerPlace;
    getLevelIdForApplicationRank(issuerRank, &issuerPlace);
    const bool issuedLocally = (sourcePlaceIsLocal(myPlaceId, issuerPlace));
    (void)issuedLocally;

    I_WinPersistent* winInfo;
    if (myPlaceId == issuerPlace)
        winInfo = myWTrack->getPersistentWin(pId, win);
    else
        winInfo = myWTrack->getPersistentRemoteWin(pId, win);

    // Unknown window handle.
    if (!winInfo)
        return false;

    // MPI_WIN_NULL: drop the reference and report failure.
    if (winInfo->isNull()) {
        winInfo->erase();
        return false;
    }

    if (pOutWin != NULL)
        *pOutWin = winInfo;
    return true;
}

//=============================
// getGroupInfo
//=============================
/**
 * Looks up the persistent info for group handle @p group, keyed by the rank
 * of @p pId.
 *
 * NOTE(review): unlike the window/communicator lookups, no local/remote place
 * distinction is made here; the previously computed (and never used)
 * place/remote flags were removed as dead code.
 *
 * @param pOutGroup receives the info on success (may be NULL).
 * @return false for a zero handle, an unknown group, or MPI_GROUP_NULL (the
 *         latter's info is released via erase()), true otherwise.
 */
bool DGroupMatch::getGroupInfo(
    MustParallelId pId,
    MustGroupType group,
    I_GroupPersistent** pOutGroup)
{
    // A zero handle is never looked up and is reported as unknown.
    if (!group)
        return false;

    int rank = myPIdMod->getInfoForId(pId).rank;
    I_GroupPersistent* gInfo = myGTrack->getPersistentGroup(rank, group);
    if (!gInfo)
        return false; // Unknown group
    if (gInfo->isNull()) {
        gInfo->erase();
        return false; // MPI_Group_Null
    }

    if (pOutGroup)
        *pOutGroup = gInfo;

    return true;
}

/**
 * Looks up the persistent info for datatype handle @p type issued by @p pId.
 *
 * The local tracker table is used when the issuing rank belongs to this
 * place, the remote table otherwise. No isNull() filtering is applied.
 *
 * @param pOutType receives the info on success (may be NULL).
 * @return false if the datatype is unknown, true otherwise.
 */
bool DGroupMatch::getTypeInfo(
    MustParallelId pId,
    MustDatatypeType type,
    I_DatatypePersistent** pOutType)
{
    const int issuerRank = myPIdMod->getInfoForId(pId).rank;
    int issuerPlace;
    getLevelIdForApplicationRank(issuerRank, &issuerPlace);

    I_DatatypePersistent* typeInfo;
    if (myPlaceId == issuerPlace)
        typeInfo = myDTrack->getPersistentDatatype(pId, type);
    else
        typeInfo = myDTrack->getPersistentRemoteDatatype(pId, type);

    // Unknown datatype handle.
    if (typeInfo == NULL)
        return false;

    if (pOutType != NULL)
        *pOutType = typeInfo;
    return true;
}

//=============================
// getCommInfo
//=============================
/**
 * Looks up the persistent info for communicator handle @p comm issued by
 * @p pId.
 *
 * The local tracker table is used when the issuing rank belongs to this
 * place, the remote table otherwise.
 *
 * @param pOutComm receives the info on success (may be NULL).
 * @return false for an unknown communicator or MPI_COMM_NULL (the latter's
 *         info is released via erase()), true otherwise.
 */
bool DGroupMatch::getCommInfo(MustParallelId pId, MustCommType comm, I_CommPersistent** pOutComm)
{
    const int issuerRank = myPIdMod->getInfoForId(pId).rank;
    int issuerPlace;
    getLevelIdForApplicationRank(issuerRank, &issuerPlace);

    I_CommPersistent* commInfo;
    if (myPlaceId == issuerPlace)
        commInfo = myCTrack->getPersistentComm(pId, comm);
    else
        commInfo = myCTrack->getPersistentRemoteComm(pId, comm);

    // Unknown communicator handle.
    if (!commInfo)
        return false;

    // MPI_COMM_NULL: drop the reference and report failure.
    if (commInfo->isNull()) {
        commInfo->erase();
        return false;
    }

    if (pOutComm != NULL)
        *pOutComm = commInfo;
    return true;
}

/**
 * Resolves a handle passed from another place to its persistent info, based on
 * @p handleType (window or communicator).
 *
 * @param pOutInfo receives the info on success (may be NULL).
 * @return true if the handle was resolved; false if the lookup failed or the
 *         handle type is not one of the supported kinds. Previously an unknown
 *         type returned true without writing *pOutInfo, so callers went on to
 *         use an uninitialized pointer.
 */
bool DGroupMatch::getPassedHandleInfo(
    MustParallelId pId,
    MustRemoteIdType remoteHandleId,
    MUST_GROUPMATCH_HANDLE handleType,
    I_Destructable** pOutInfo)
{
    switch (handleType) {
    case MUST_GROUPMATCH_WIN_HANDLE: {
        I_WinPersistent* winInfo = NULL;
        if (!getWinInfo(pId, remoteHandleId, &winInfo))
            return false;
        if (pOutInfo)
            *pOutInfo = winInfo;
        return true;
    }
    case MUST_GROUPMATCH_COMM_HANDLE: {
        I_CommPersistent* commInfo = NULL;
        if (!getCommInfo(pId, remoteHandleId, &commInfo))
            return false;
        if (pOutInfo)
            *pOutInfo = commInfo;
        return true;
    }
    default:
        // Unsupported handle kind: nothing was resolved, report failure.
        return false;
    }
}

//=============================
// findMatchingOp
//=============================
/**
 * Tries to match every outstanding partner entry of @p op against the
 * opposite-direction operations (sends vs. recvs) already queued for that
 * partner rank on an equivalent handle.
 *
 * For each hit the partner types are matched, the listener (if any) is
 * notified, the matched entries are removed from both operations, and a queued
 * operation whose partner list becomes empty is removed from its queue and
 * deleted. Only the first matching queued operation is used per partner.
 *
 * Fixes over the previous version:
 *  - erasing the first partner no longer decrements begin() (UB);
 *  - a completed queued operation is deleted through a pointer copied before
 *    its iterator is erased, instead of dereferencing the erased iterator.
 *
 * @return true if all partner entries of op are matched.
 */
bool DGroupMatch::findMatchingOp(DGroupOp* op)
{
    // Attempts to match a single partner entry of op; returns true on a hit.
    auto matchPartner = [this, op](DGroupOpPartnerInfo* partner) -> bool {
        //==1) Get the process that would need to call the send/MPI_Win_post
        QT::iterator q = myQs.find(partner->myTargetRank);
        if (q == myQs.end())
            return false;

        //==2) Get the table
        /*
         * Win handles may be different on each process,
         * so we have to ask the win tracker whether any
         * of the wins present for the destination are equal
         * to the given win.
         */
        ProcessQueues::iterator table = q->second.begin();
        while (table != q->second.end() && !op->compareHandles(table->first))
            ++table;
        if (table == q->second.end())
            return false;

        //==3) search for a matching operation
        RankOutstandings* searchList =
            op->isSend() ? &table->second.recvs  // find a matching recv to our send
                         : &table->second.sends; // find a matching send to our recv

        for (RankOutstandings::iterator otherOp = searchList->begin(); otherOp != searchList->end();
             ++otherOp) {
            // Copy the pointer now: otherOp may be erased below.
            auto other = *otherOp;

            // Is the issuing rank of op a partner of other?
            std::vector<DGroupOpPartnerInfo*>::iterator it = other->getPartnerInfos().begin();
            while (it != other->getPartnerInfos().end() &&
                   (*it)->myTargetRank != op->getIssuerRank())
                ++it;
            if (it == other->getPartnerInfos().end())
                continue;

            // Both must be collectives or both non-collectives.
            if (op->isCollective() != other->isCollective())
                continue;

            if (op->isCollective()) {
                // Neighborhood collectives: report mismatching calls but still match.
                if (op->getCallId() != other->getCallId())
                    op->printCollectiveMismatch(other);
            } else if (op->getCallId() != other->getCallId()) {
                // PSCW: MPI_Win_post pairs with MPI_Win_start and MPI_Win_complete
                // with MPI_Win_wait; keep looking on a differing callId.
                continue;
            }

            // Match types
            partner->matchTypes(*it);

            // HIT! notify the listener (recv side first, send side second).
            if (myListener) {
                if (op->isSend()) {
                    myListener->notifyP2PRecvMatch(
                        other->getPId(),
                        (*it)->myTimestamp,
                        op->getPId(),
                        partner->myTimestamp);
                } else {
                    myListener->notifyP2PRecvMatch(
                        op->getPId(),
                        partner->myTimestamp,
                        other->getPId(),
                        (*it)->myTimestamp);
                }
            }

            // Remove the matched entry from other; drop other once it is complete.
            other->getPartnerInfos().erase(it);
            if (other->getPartnerInfos().empty()) {
                searchList->erase(otherOp);
                delete other;
            }
            // We only use the first matching op
            return true;
        }
        return false;
    };

    // Erase matched partners of op; advance only past unmatched ones.
    auto currentPartnerIt = op->getPartnerInfos().begin();
    while (currentPartnerIt != op->getPartnerInfos().end()) {
        if (matchPartner(*currentPartnerIt))
            currentPartnerIt = op->getPartnerInfos().erase(currentPartnerIt);
        else
            ++currentPartnerIt;
    }

    return op->getPartnerInfos().empty();
}

//=============================
// addOutstandingOp
//=============================
/**
 * Queues @p op as outstanding under its issuing rank and persistent handle,
 * in the send list or the recv list depending on op->isSend().
 */
void DGroupMatch::addOutstandingOp(DGroupOp* op)
{
    // Queues of the issuing process (created on first use).
    ProcessQueues& queues = myQs.emplace(op->getIssuerRank(), ProcessQueues()).first->second;

    // Table for op's persistent handle; on a miss a handle copy becomes the
    // key (the original code notes this copy increases the handle's ref count).
    ProcessQueues::iterator tableIt = queues.find(op->getPersistentHandle());
    if (tableIt == queues.end())
        tableIt = queues.emplace(op->getPersistentHandleCopy(), ProcessTable()).first;

    // Append to the matching direction list.
    if (op->isSend())
        tableIt->second.sends.push_back(op);
    else
        tableIt->second.recvs.push_back(op);
}

//=============================
// registerListener
//=============================
/**
 * Sets the listener notified about new group operations and matches.
 * Only one listener is stored; a later call replaces the previous one.
 */
void DGroupMatch::registerListener(I_DGroupListener* listener)
{
    myListener = listener;
}

/**
 * Notifies the listener (if any) of a new group operation once per rank in
 * @p ranksToNotify and appends one timestamp per rank to @p pOutTimestamps
 * (0 when no listener is registered).
 *
 * Before each listener call, copy() is invoked on the non-NULL handle infos;
 * presumably the listener takes over that reference — TODO confirm.
 */
void DGroupMatch::notifyListenerOfNewOp(
    MustParallelId pId,
    MustLocationId lId,
    I_CommPersistent* cInfo,
    I_WinPersistent* wInfo,
    bool isSend,
    MustSendMode sendMode,
    std::vector<int> ranksToNotify,
    std::vector<MustLTimeStamp>& pOutTimestamps)
{
    for (int targetRank : ranksToNotify) {
        MustLTimeStamp stamp = 0;
        if (myListener != NULL) {
            if (cInfo != NULL)
                cInfo->copy();
            if (wInfo != NULL)
                wInfo->copy();
            stamp = myListener->newGroupOp(pId, lId, cInfo, wInfo, isSend, targetRank, sendMode);
        }
        pOutTimestamps.push_back(stamp);
    }
}

/**
 * Notifies the listener (if any) of a new group operation for each of the
 * @p size ranks in @p ranksToNotify, and groups the ranks by the place
 * (level id) they belong to.
 *
 * For every rank, the rank and its listener timestamp (0 without a listener)
 * are appended to the PassInfo of its place, then @p insertExtraFields is
 * invoked with that PassInfo and the rank's index to append per-rank extras.
 *
 * Uses a size_t loop index (previously int compared against size_t).
 *
 * @return map from place id to the collected PassInfo.
 */
template <typename Inserter>
std::map<int, PassInfo> DGroupMatch::getPlacesAndNotifyListenerInternal(
    MustParallelId pId,
    MustLocationId lId,
    I_CommPersistent* cInfo,
    I_WinPersistent* wInfo,
    bool isSend,
    MustSendMode sendMode,
    const int* ranksToNotify,
    size_t size,
    Inserter insertExtraFields)
{
    std::map<int, PassInfo> places;
    for (size_t i = 0; i < size; i++) {
        const int rank = ranksToNotify[i];

        MustLTimeStamp eventLTime = 0;
        if (myListener) {
            // One reference per notification handed to the listener.
            if (cInfo)
                cInfo->copy();
            if (wInfo)
                wInfo->copy();
            eventLTime = myListener->newGroupOp(pId, lId, cInfo, wInfo, isSend, rank, sendMode);
        }

        int place = 0;
        getLevelIdForApplicationRank(rank, &place);
        PassInfo& info = places[place];
        info.ranks.emplace_back(rank);
        info.timestamps.emplace_back(eventLTime);
        insertExtraFields(info, i);
    }
    return places;
}

/**
 * Notifies the listener per rank and groups ranks and timestamps by place,
 * without any per-rank extra fields.
 */
std::map<int, PassInfo> DGroupMatch::getPlacesAndNotifyListener(
    MustParallelId pId,
    MustLocationId lId,
    I_CommPersistent* cInfo,
    I_WinPersistent* wInfo,
    bool isSend,
    MustSendMode sendMode,
    const int* ranksToNotify,
    size_t size)
{
    // No extra per-rank data to record.
    auto noExtras = [](PassInfo&, int) {};
    return getPlacesAndNotifyListenerInternal(
        pId, lId, cInfo, wInfo, isSend, sendMode, ranksToNotify, size, noExtras);
}

/**
 * Notifies the listener per rank and groups ranks, timestamps and the
 * per-rank send counts (@p sendCounts, indexed like @p ranksToNotify) by place.
 */
std::map<int, PassInfo> DGroupMatch::getPlacesAndNotifyListener(
    MustParallelId pId,
    MustLocationId lId,
    I_CommPersistent* cInfo,
    I_WinPersistent* wInfo,
    bool isSend,
    MustSendMode sendMode,
    const int* ranksToNotify,
    const int* sendCounts,
    size_t size)
{
    // Record the send count belonging to each rank.
    auto appendCount = [sendCounts](PassInfo& dst, int idx) {
        dst.sendcounts.emplace_back(sendCounts[idx]);
    };
    return getPlacesAndNotifyListenerInternal(
        pId, lId, cInfo, wInfo, isSend, sendMode, ranksToNotify, size, appendCount);
}

/**
 * Notifies the listener per rank and groups ranks, timestamps, send counts
 * and datatypes (@p sendCounts / @p types indexed like @p ranksToNotify) by
 * place.
 */
std::map<int, PassInfo> DGroupMatch::getPlacesAndNotifyListener(
    MustParallelId pId,
    MustLocationId lId,
    I_CommPersistent* cInfo,
    I_WinPersistent* wInfo,
    bool isSend,
    MustSendMode sendMode,
    const int* ranksToNotify,
    const int* sendCounts,
    const MustDatatypeType* types,
    size_t size)
{
    // Record the send count and datatype belonging to each rank.
    auto appendCountAndType = [sendCounts, types](PassInfo& dst, int idx) {
        dst.sendcounts.emplace_back(sendCounts[idx]);
        dst.types.emplace_back(types[idx]);
    };
    return getPlacesAndNotifyListenerInternal(
        pId, lId, cInfo, wInfo, isSend, sendMode, ranksToNotify, size, appendCountAndType);
}

/*EOF*/
