/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file WinTrack.cpp
 *       @see MUST::WinTrack.
 *
 *  @date 26.04.2017
 *  @author Simon Schwitanski
 */

#include "GtiMacros.h"

#include "WinTrack.h"

#include <cassert>
#include <cstdlib>
#include <iostream>
#include <list>
#include <sstream>

using namespace must;

// GTI module boilerplate for WinTrack, generated by macros from GtiMacros.h.
// By their names these provide the instance getter, the instance release function and the
// PnMPI registration point — see GtiMacros.h for the exact expansions.
mGET_INSTANCE_FUNCTION(WinTrack)
mFREE_INSTANCE_FUNCTION(WinTrack)
mPNMPI_REGISTRATIONPOINT_FUNCTION(WinTrack)

//=============================
// Constructor
//=============================
/**
 * Creates the WinTrack module.
 *
 * Requires four child modules, accessed by fixed position in myFurtherMods:
 *   [0] DatatypeTrack, [1] CommTrack, [2] GroupTrack, [3] BaseConstants.
 * Also resolves the wrap-across functions used to forward windows and window frees.
 *
 * @param instanceName name of this module instance (passed to TrackBase).
 */
WinTrack::WinTrack(const char* instanceName)
    : TrackBase<Win, I_Win, MustWinType, MustMpiWinPredefined, WinTrack, I_WinTrack>(instanceName)
{
    // All four child modules are indexed unconditionally below, so a missing one is fatal.
    if (myFurtherMods.size() < 4) {
        std::cerr << "Error: the WinTrack module needs the DatatypeTrack, CommTrack, GroupTrack and "
                     "BaseConstants modules as children, but at least one of them was not "
                     "available."
                  << std::endl;
        assert(0);
        // assert() is compiled out with NDEBUG; without this abort the indexing below would read
        // past the end of myFurtherMods (undefined behavior).
        std::abort();
    }

    myDTrack = (I_DatatypeTrack*)myFurtherMods[0];
    myCTrack = (I_CommTrack*)myFurtherMods[1];
    myGTrack = (I_GroupTrack*)myFurtherMods[2];
    myConsts = (I_BaseConstants*)myFurtherMods[3];

    // Initialize module data
    getWrapAcrossFunction("passWinAcross", (GTI_Fct_t*)&myPassWinAcrossFunc);
    getWrapAcrossFunction("passFreeWinAcross", (GTI_Fct_t*)&myPassFreeWinAcrossFunc);
}

//=============================
// freeWin
//=============================
/**
 * Handles the freeing of a window by dropping its user handle.
 * Handles that are not known to this module are ignored.
 *
 * NOTE(review): an earlier TODO stated the handle should not be removed here because the window
 * information might still be needed elsewhere, yet the handle is removed — confirm intent.
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN WinTrack::freeWin(MustParallelId pId, MustLocationId lId, MustWinType win)
{
    if (getHandleInfo(pId, win) != NULL)
        removeUserHandle(pId, win);

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// attachWin
//=============================
/**
 * Handles attaching a memory region to a window: the region starting at base with the given
 * size is added to the window's memory intervals. Unknown windows are ignored.
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN WinTrack::attachWin(
    MustParallelId pId,
    MustLocationId lId,
    MustAddressType base,
    int size,
    MustWinType win)
{
    Win* const winInfo = getHandleInfo(pId, win);
    if (winInfo != NULL)
        addMemoryInterval(*winInfo, base, size);

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// detachWin
//=============================
/**
 * Handles detaching a memory region from a window: the first memory interval whose base address
 * equals base is removed. Unknown windows or unmatched addresses are ignored.
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN
WinTrack::detachWin(MustParallelId pId, MustLocationId lId, MustAddressType base, MustWinType win)
{
    Win* const winInfo = getHandleInfo(pId, win);
    if (winInfo == NULL)
        return GTI_ANALYSIS_SUCCESS;

    // Linear scan (O(n)) for the first interval starting at 'base'.
    MustMemIntervalListType& intervals = winInfo->myMemIntervals;
    MustMemIntervalListType::iterator pos = intervals.begin();
    while (pos != intervals.end() && !(pos->baseAddress == base))
        ++pos;

    if (pos != intervals.end())
        intervals.erase(pos);

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// createPredefinedInfo
//=============================
/**
 * Creates the info object for a predefined window handle.
 * Only the null handle (myNullValue) is expected; any other handle yields NULL.
 *
 * @param value unused.
 * @param handle predefined handle value.
 * @return a new Win for the null handle, NULL otherwise.
 */
Win* WinTrack::createPredefinedInfo(int value, MustWinType handle)
{
    Win* predefined = NULL;
    if (handle == myNullValue)
        predefined = new Win();
    return predefined;
}

//============================
// addWin
//============================
/**
 * Registers a newly created window under the given user handle.
 *
 * The window takes a persistent reference to its communicator (via CommTrack) and draws the next
 * context id from it. Dynamic windows use MPI_BOTTOM (myConsts->getBottom()) as base and start
 * without memory intervals; all other kinds get [base, size) registered immediately.
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN WinTrack::addWin(
    MustParallelId pId,
    MustLocationId lId,
    int kind,
    int memoryModel,
    MustCommType comm,
    void* base,
    int size,
    int disp_unit,
    MustWinType win)
{
    Win* newWin = new Win();

    // Window properties and creation context
    newWin->myKind = (MUST_WIN_KIND)kind;
    newWin->myMemoryModel = (MUST_WIN_MEMORY_MODEL)memoryModel;
    newWin->myCreationPId = pId;
    newWin->myCreationLId = lId;
    newWin->myIsNull = false;
    newWin->myDispUnit = disp_unit;

    // Communicator of the window; the window consumes the comm's next context id
    newWin->myCommHandle = comm;
    newWin->myComm = myCTrack->getPersistentComm(pId, comm);
    newWin->myContextId = newWin->myComm->getNextContextId();

    const bool isDynamic = (newWin->myKind == MUST_WIN_DYNAMIC);
    newWin->myBase =
        isDynamic ? (MustAddressType)myConsts->getBottom() : (MustAddressType)base;

#ifdef MUST_DEBUG
    if (newWin->myMemoryModel == MUST_WIN_MEMORY_UNKNOWN)
        std::cout << "Warning: Memory model of RMA window is unknown!" << std::endl;
    std::cout << "window base: " << newWin->myBase << std::endl;
#endif

    // Dynamic windows receive their memory regions later via attachWin
    if (!isDynamic)
        addMemoryInterval(*newWin, newWin->myBase, size);

    submitUserHandle(pId, win, newWin);
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// addMemoryInterval
//=============================
/**
 * Inserts a memory region of a window into win.myMemIntervals.
 *
 * The region is described by a single StridedBlock built from base and size (presumably the
 * contiguous range [base, base + size) — confirm against the StridedBlock constructor); the
 * remaining memInterval fields are filled with zero/NULL placeholders and base.
 *
 * @param win window whose interval set is extended.
 * @param base start address of the region.
 * @param size size of the region.
 */
void WinTrack::addMemoryInterval(Win& win, MustAddressType base, MustAddressType size)
{
    win.myMemIntervals.insert(memInterval(
        StridedBlock(base, base, true, 0, 1, size, 0),
        0,
        (MustRequestType)0, // NOTE(review): presumably "no associated request" — confirm
        true,
        NULL,
        base,
        0));
}

//=============================
// addRemoteWin
//=============================
/**
 * Registers a window received from another place (the receiving end of passWinAcross).
 * The communicator is resolved through CommTrack's remote ids; the comm handle is not known here
 * and is set to 0.
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN WinTrack::addRemoteWin(
    int rank,
    int hasHandle,
    MustWinType winHandle,
    MustRemoteIdType remoteId,
    int kind,
    int memoryModel,
    MustRemoteIdType commId,
    MustAddressType base,
    int dispUnit,
    unsigned long long contextId,
    MustParallelId creationPId,
    MustLocationId creationLId)
{
    Win* remoteWin = new Win();

    // Properties as reported by the sender
    remoteWin->myKind = (MUST_WIN_KIND)kind;
    remoteWin->myMemoryModel = (MUST_WIN_MEMORY_MODEL)memoryModel;
    remoteWin->myBase = base;
    remoteWin->myDispUnit = dispUnit;
    remoteWin->myContextId = contextId;
    remoteWin->myIsNull = false;

    // Communicator, looked up by its remote id
    remoteWin->myComm = myCTrack->getPersistentRemoteComm(rank, commId);
    remoteWin->myCommHandle = 0;

    // Creation context
    remoteWin->myCreationPId = creationPId;
    remoteWin->myCreationLId = creationLId;

    submitRemoteResource(rank, remoteId, hasHandle, winHandle, remoteWin);
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// freeRemoteWin
//=============================
/**
 * Drops the remote window registered under (rank, remoteId).
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN WinTrack::freeRemoteWin(int rank, MustRemoteIdType remoteId)
{
    removeRemoteResource(rank, remoteId);

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// getWin
//=============================
/** Looks up a window by handle for the process identified by pId (non-persistent access). */
I_Win* WinTrack::getWin(MustParallelId pId, MustWinType win)
{
    const int rank = pId2Rank(pId);
    return getWin(rank, win);
}

//=============================
// getWin
//=============================
/**
 * Looks up a window by handle for the given rank.
 * The reference count is not modified; NULL if the handle is unknown.
 */
I_Win* WinTrack::getWin(int rank, MustWinType win)
{
    Win* const found = getHandleInfo(rank, win);
    return found;
}

//=============================
// getPersistentWin
//=============================
/** Persistent window lookup for the process identified by pId; see the rank-based overload. */
I_WinPersistent* WinTrack::getPersistentWin(MustParallelId pId, MustWinType win)
{
    const int rank = pId2Rank(pId);
    return getPersistentWin(rank, win);
}

//=============================
// getPersistentWin
//=============================
/**
 * Persistent window lookup: like getWin, but increments the window's reference count when the
 * window is found. Returns NULL for unknown handles.
 */
I_WinPersistent* WinTrack::getPersistentWin(int rank, MustWinType win)
{
    Win* const found = getHandleInfo(rank, win);
    if (found == NULL)
        return NULL;

    found->incRefCount();
    return found;
}

//=============================
// getRemoteWin
//=============================
/** Looks up a remote window by remote id for the process identified by pId. */
I_Win* WinTrack::getRemoteWin(MustParallelId pId, MustRemoteIdType remoteId)
{
    const int rank = pId2Rank(pId);
    return getRemoteWin(rank, remoteId);
}

//=============================
// getRemoteWin
//=============================
/**
 * Looks up a remote window by remote id for the given rank.
 * The reference count is not modified; NULL if the id is unknown.
 */
I_Win* WinTrack::getRemoteWin(int rank, MustRemoteIdType remoteId)
{
    return getRemoteIdInfo(rank, remoteId);
}

//=============================
// getPersistentRemoteWin
//=============================
/** Persistent remote window lookup for the process identified by pId. */
I_WinPersistent* WinTrack::getPersistentRemoteWin(MustParallelId pId, MustRemoteIdType remoteId)
{
    const int rank = pId2Rank(pId);
    return getPersistentRemoteWin(rank, remoteId);
}

//=============================
// getPersistentRemoteWin
//=============================
/**
 * Persistent remote window lookup: increments the reference count of the window when it is
 * found. Returns NULL for unknown remote ids.
 */
I_WinPersistent* WinTrack::getPersistentRemoteWin(int rank, MustRemoteIdType remoteId)
{
    Win* const found = getRemoteIdInfo(rank, remoteId);
    if (found == NULL)
        return NULL;

    found->incRefCount();
    return found;
}

/**
 * Finds the user handle on targetRank whose window compares equal to the remote window
 * registered under (remoteRank, remoteId).
 *
 * @param remoteRank rank under which the remote window is registered.
 * @param targetRank rank whose user handles are searched.
 * @param remoteId remote id of the window to match.
 * @return the matching handle. If the remote window is unknown or no handle matches, the
 *         function asserts; with NDEBUG it returns 0.
 */
MustWinType WinTrack::getMatchingWin(int remoteRank, int targetRank, MustRemoteIdType remoteId)
{
    I_Win* remoteWin = getRemoteWin(remoteRank, remoteId);

    // getRemoteWin returns NULL for unknown ids; comparing through a NULL pointer would be
    // undefined behavior, so an unknown remote window is treated as "no match".
    if (remoteWin) {
        std::list<std::pair<int, MustWinType>> handles = getUserHandles();
        for (std::list<std::pair<int, MustWinType>>::iterator it = handles.begin();
             it != handles.end();
             ++it) {
            if (it->first != targetRank)
                continue;

            // getWin may return NULL as well; skip such entries instead of dereferencing them.
            I_Win* localWin = getWin(it->first, it->second);
            if (localWin && *remoteWin == *localWin)
                return it->second;
        }
    }

    assert(0);
    return 0;
}

//=============================
// passWinAcross
//=============================
/** Forwards the window behind a handle of the process identified by pId to toPlaceId. */
bool WinTrack::passWinAcross(MustParallelId pId, MustWinType win, int toPlaceId)
{
    const int rank = pId2Rank(pId);
    return passWinAcross(rank, win, toPlaceId);
}

//=============================
// passWinAcross
//=============================
/**
 * Forwards the window behind a user handle of the given rank to toPlaceId.
 * The handle is forwarded along with the window; no remote id is reported back.
 *
 * @return false if the window could not be forwarded (see passWinAcrossInternal).
 */
bool WinTrack::passWinAcross(int rank, MustWinType winHandle, int toPlaceId)
{
    return passWinAcrossInternal(
        rank,
        getHandleInfo(rank, winHandle),
        toPlaceId,
        NULL,
        true,
        winHandle);
}

//=============================
// passWinAcross
//=============================
bool WinTrack::passWinAcross(int rank, I_Win* winIn, int toPlaceId, MustRemoteIdType* pOutRemoteId)
{
    if (!winIn)
        return false; // Invalid win

    // Cast to internal representation
    Win* win = (Win*)winIn;

    // Do we still have a handle associated?
    MustWinType handle = 0;
    bool hasHandle = getHandleForInfo(rank, win, &handle);

    return passWinAcrossInternal(rank, win, toPlaceId, pOutRemoteId, hasHandle, handle);
}

//=============================
// passWinAcrossInternal
//=============================
/**
 * Forwards a window to another place, at most once per (toPlaceId, rank).
 *
 * Before the window itself, the creation location and the window's communicator are forwarded
 * (presumably so the receiver can resolve the ids sent along with the window).
 *
 * @param rank rank the window belongs to.
 * @param win window info; NULL yields false.
 * @param toPlaceId destination place.
 * @param pOutRemoteId if non-NULL, receives the window's remote id (also when the window was
 *        already forwarded earlier).
 * @param hasHandle whether a user handle is still associated with the window.
 * @param handle that user handle (only meaningful if hasHandle).
 * @return false if no wrap-across function is available or win is NULL, true otherwise.
 */
bool WinTrack::passWinAcrossInternal(
    int rank,
    Win* win,
    int toPlaceId,
    MustRemoteIdType* pOutRemoteId,
    bool hasHandle,
    MustWinType handle)
{
    // Do we have wrap-across at all?
    if (!myPassWinAcrossFunc)
        return false;

    // Valid info?
    if (!win)
        return false;

    // Store the remote id
    if (pOutRemoteId)
        *pOutRemoteId = win->getRemoteId();

    // Did we already pass this win?
    if (win->wasForwardedToPlace(toPlaceId, rank))
        return true;

    // Pass base resources of the win
    myLIdMod->passLocationToPlace(win->myCreationPId, win->myCreationLId, toPlaceId);

    // NOTE(review): the result of passCommAcross is not checked; if forwarding the comm fails,
    // commId stays 0 and is sent as such.
    MustRemoteIdType commId = 0;
    myCTrack->passCommAcross(rank, win->myComm, toPlaceId, &commId);

    // Pass the actual win across (received by addRemoteWin-style handlers at the destination)
    (*myPassWinAcrossFunc)(
        rank,
        (int)hasHandle,
        handle,
        win->getRemoteId(),
        (int)win->myKind,
        (int)win->myMemoryModel,
        commId,
        win->myBase,
        win->myDispUnit,
        win->myContextId,
        win->myCreationPId,
        win->myCreationLId,
        toPlaceId);

    // Tell the win that we passed it across; the free-forwarding function is handed along so a
    // later free can be forwarded too
    win->setForwardedToPlace(toPlaceId, rank, myPassFreeWinAcrossFunc);

    return true;
}

//=============================
// startEpoch
//=============================
/**
 * Records the start of an RMA synchronization epoch on a window for a single target rank.
 *
 * @param pId parallel id of the calling process.
 * @param lId location id of the opening call; stored under the key for (syncKind, targetRank).
 * @param targetRank target rank of the epoch.
 * @param syncKind MUST_WIN_EPOCH_SYNC flag(s) of the epoch; MUST_WIN_EPOCH_NONE is ignored.
 * @param assert assertion argument of the opening call; stored in myEpoch.myAssertion.
 * @param win window handle; unknown windows are ignored.
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN
WinTrack::startEpoch(
    MustParallelId pId,
    MustLocationId lId,
    int targetRank,
    int syncKind,
    int assert,
    MustWinType win)
{
    if ((MUST_WIN_EPOCH_SYNC)syncKind == MUST_WIN_EPOCH_NONE)
        return GTI_ANALYSIS_SUCCESS;

    // find window
    Win* info = getHandleInfo(pId, win);
    if (!info)
        return GTI_ANALYSIS_SUCCESS;

    // The fence epoch was not opened and we need to remove the locationId
    // NOTE(review): the key erased here is computed exactly like the key assigned a few lines
    // below, so this erase has no lasting effect. startEpochGroup instead erases the key built
    // with -MUST_WIN_EPOCH_POTENTIAL_FENCE as target — confirm which key the fence entry uses.
    if (info->myEpoch.mySync & MUST_WIN_EPOCH_POTENTIAL_FENCE) {
        info->myEpoch.mySync &= ~MUST_WIN_EPOCH_POTENTIAL_FENCE;
        int key = info->myEpoch.getEpochLIdKey((MUST_WIN_EPOCH_SYNC)syncKind, targetRank);
        info->myEpoch.myEpochStartLIds.erase(key);
    }

    // Remember where the epoch was opened and mark it as active
    int key = info->myEpoch.getEpochLIdKey((MUST_WIN_EPOCH_SYNC)syncKind, targetRank);
    info->myEpoch.myEpochStartLIds[key] = lId;
    info->myEpoch.mySync |= (MUST_WIN_EPOCH_SYNC)syncKind;
    info->myEpoch.myAssertion = assert;

    return GTI_ANALYSIS_SUCCESS;
};

//=============================
// startEpochGroup
//=============================
/**
 * Records the start of an RMA epoch whose targets are given as an MPI group.
 *
 * Each group member that is also part of the window's communicator is added (as window rank) to
 * myEpoch.myWinStartRanks. The location id is stored under the key built from
 * (syncKind, -syncKind).
 *
 * @param pId parallel id of the calling process.
 * @param lId location id of the opening call.
 * @param group group handle describing the target processes.
 * @param syncKind MUST_WIN_EPOCH_SYNC flag(s) of the epoch; MUST_WIN_EPOCH_NONE is ignored.
 * @param assert assertion argument of the opening call; stored in myEpoch.myAssertion.
 * @param win window handle.
 * @return always GTI_ANALYSIS_SUCCESS; unknown windows or groups are ignored.
 */
GTI_ANALYSIS_RETURN WinTrack::startEpochGroup(
    MustParallelId pId,
    MustLocationId lId,
    MustGroupType group,
    int syncKind,
    int assert,
    MustWinType win)
{
    if ((MUST_WIN_EPOCH_SYNC)syncKind == MUST_WIN_EPOCH_NONE)
        return GTI_ANALYSIS_SUCCESS;

    // find window
    Win* winInfo = getHandleInfo(pId, win);
    if (!winInfo)
        return GTI_ANALYSIS_SUCCESS;

    // find group
    I_Group* groupInfo = myGTrack->getGroup(pId, group);
    if (!groupInfo)
        return GTI_ANALYSIS_SUCCESS;

    // The fence epoch was not opened and we need to remove the locationId
    if (winInfo->myEpoch.mySync & MUST_WIN_EPOCH_POTENTIAL_FENCE) {
        winInfo->myEpoch.mySync &= ~MUST_WIN_EPOCH_POTENTIAL_FENCE;
        int key = winInfo->myEpoch.getEpochLIdKey(
            (MUST_WIN_EPOCH_SYNC)syncKind,
            -MUST_WIN_EPOCH_POTENTIAL_FENCE);
        winInfo->myEpoch.myEpochStartLIds.erase(key);
    }

    int grpSize = groupInfo->getGroup()->getSize();
    int worldRank;
    int winRank;
    for (int i = 0; i < grpSize; i++) {
        // translate group rank to window rank
        // NOTE(review): the return value of translate() is ignored; if it can fail, worldRank is
        // read uninitialized below — confirm translate always succeeds for 0 <= i < grpSize.
        groupInfo->getGroup()->translate(i, &worldRank);
        // Members outside the window's communicator are skipped
        if (winInfo->getComm()->getGroup()->containsWorldRank(worldRank, &winRank))
            winInfo->myEpoch.myWinStartRanks.insert(winRank);
    }

    // Group epochs use -syncKind as target component of the key
    int key = winInfo->myEpoch.getEpochLIdKey((MUST_WIN_EPOCH_SYNC)syncKind, -syncKind);
    winInfo->myEpoch.myEpochStartLIds[key] = lId;
    winInfo->myEpoch.mySync |= (MUST_WIN_EPOCH_SYNC)syncKind;
    winInfo->myEpoch.myAssertion = assert;

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// endEpoch
//=============================
/**
 * Records the end of an RMA epoch on a window.
 *
 * Removes the stored start location for (syncKind, targetRank). For lock epochs the sync state is
 * kept as long as any start-location entry with key >= 0 remains (per the comment below, such
 * entries denote still-locked ranks). For post epochs the set of start ranks is cleared. Then the
 * syncKind bits are cleared; once no epoch remains, the stored assertion and last RMA location are
 * reset.
 *
 * @param pId parallel id of the calling process.
 * @param targetRank target rank used to build the location-id key.
 * @param syncKind MUST_WIN_EPOCH_SYNC flag(s) of the epoch being closed.
 * @param win window handle; unknown windows are ignored.
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN
WinTrack::endEpoch(MustParallelId pId, int targetRank, int syncKind, MustWinType win)
{
    // find window
    Win* info = getHandleInfo(pId, win);
    if (!info)
        return GTI_ANALYSIS_SUCCESS;

    int key = info->myEpoch.getEpochLIdKey((MUST_WIN_EPOCH_SYNC)syncKind, targetRank);
    info->myEpoch.myEpochStartLIds.erase(key);

    // Return if MPI_Win_unlock was called and there are still locked ranks
    if ((MUST_WIN_EPOCH_SYNC)syncKind & MUST_WIN_EPOCH_LOCK &&
        info->myEpoch.myEpochStartLIds.lower_bound(0) != info->myEpoch.myEpochStartLIds.end()) {
        return GTI_ANALYSIS_SUCCESS;
    }

    // Ranks collected by startEpochGroup are only valid for the epoch being closed
    if ((MUST_WIN_EPOCH_SYNC)syncKind & MUST_WIN_EPOCH_POST) {
        info->myEpoch.myWinStartRanks.clear();
    }

    info->myEpoch.mySync &= ~(MUST_WIN_EPOCH_SYNC)syncKind;

    // Reset information if no epoch is active
    if (info->myEpoch.mySync == MUST_WIN_EPOCH_NONE) {
        info->myEpoch.myAssertion = 0;
        info->myEpoch.mylastRMALId = MUST_INVALID_LOCATION_ID;
    }
    return GTI_ANALYSIS_SUCCESS;
};

//=============================
// addRMACallInEpoch
//=============================
/**
 * Records an RMA communication call inside the current epoch of a window.
 *
 * Calls targeting MPI_PROC_NULL have no effect and are ignored, as are unknown windows.
 * Otherwise the call's location id becomes the window's last RMA location, and a pending
 * "potential" fence (after MPI_Win_fence) is turned into an actual fence access epoch.
 *
 * @return always GTI_ANALYSIS_SUCCESS.
 */
GTI_ANALYSIS_RETURN
WinTrack::addRMACallInEpoch(MustParallelId pId, MustLocationId lId, int targetRank, MustWinType win)
{
    if (!myConsts->isProcNull(targetRank)) {
        Win* const winInfo = getHandleInfo(pId, win);
        if (winInfo != NULL) {
            winInfo->myEpoch.mylastRMALId = lId;

            // First RMA call after a fence opens the fence access epoch
            if (winInfo->myEpoch.mySync & MUST_WIN_EPOCH_POTENTIAL_FENCE) {
                winInfo->myEpoch.mySync &= ~MUST_WIN_EPOCH_POTENTIAL_FENCE;
                winInfo->myEpoch.mySync |= MUST_WIN_EPOCH_FENCE;
            }
        }
    }

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// Destructor
//=============================
/**
 * Shuts the module down.
 *
 * Disables forwarding of free operations across places (HandleInfoBase) and notifies the
 * DatatypeTrack and CommTrack child modules of the ongoing shutdown.
 *
 * NOTE(review): myGTrack (GroupTrack) is obtained in the constructor but not notified here —
 * confirm whether I_GroupTrack needs notifyOfShutdown() as well.
 */
WinTrack::~WinTrack()
{
    // Notify HandleInfoBase of ongoing shutdown
    HandleInfoBase::disableFreeForwardingAcross();
    myDTrack->notifyOfShutdown();
    myCTrack->notifyOfShutdown();
}
/*EOF*/
