/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file WinChecks.cpp
 *       @see WinChecks.
 *
 *  @date 14.11.2024
 *  @author Cornelius Pätzold
 */

#include "GtiMacros.h"
#include "MustEnums.h"
#include "PrefixedOstream.hpp"
#include "WinChecks.h"
#include "Win.h"
#include "MustDefines.h"
#include "StridedBlock.h"

using namespace must;

// GTI module boilerplate for WinChecks. Judging by their names, these macros
// generate the instance getter, the instance release function and the PnMPI
// registration point of this module (presumably defined in GtiMacros.h).
mGET_INSTANCE_FUNCTION(WinChecks)
mFREE_INSTANCE_FUNCTION(WinChecks)
mPNMPI_REGISTRATIONPOINT_FUNCTION(WinChecks)

//=============================
// Constructor
//=============================
WinChecks::WinChecks(const char* instanceName)
    : gti::ModuleBase<WinChecks, I_WinChecks>(instanceName)
{
    // Sub module instances arrive in the order of the analysis specification
    // and are bound to the member pointers by position at the end.
    const std::vector<I_Module*>::size_type expectedCount = 6;
    std::vector<I_Module*> subModules = createSubModuleInstances();

    // Too few instances means the analysis specification is broken.
    if (subModules.size() < expectedCount) {
        must::cerr << "Module has not enough sub modules, check its analysis specification! ("
                   << __FILE__ << "@" << __LINE__ << ")" << std::endl;
        assert(0);
    }

    // Instances beyond the expected count are not used by this module.
    for (auto surplus = expectedCount; surplus < subModules.size(); ++surplus)
        destroySubModuleInstance(subModules[surplus]);

    // Positional binding; must match the order in the analysis specification.
    myPIdMod = (I_ParallelIdAnalysis*)subModules[0];
    myConsts = (I_BaseConstants*)subModules[1];
    myLogger = (I_CreateMessage*)subModules[2];
    myWinMod = (I_WinTrack*)subModules[3];
    myGroupMod = (I_GroupTrack*)subModules[4];
    myArgMod = (I_ArgumentAnalysis*)subModules[5];
}

//=============================
// Destructor
//=============================
/**
 * Releases every sub module instance bound in the constructor.
 *
 * All six pointers obtained from createSubModuleInstances() are handed back
 * through destroySubModuleInstance() and reset to NULL afterwards.
 */
WinChecks::~WinChecks()
{
    if (myPIdMod)
        destroySubModuleInstance((I_Module*)myPIdMod);
    myPIdMod = NULL;

    if (myLogger)
        destroySubModuleInstance((I_Module*)myLogger);
    myLogger = NULL;

    if (myWinMod)
        destroySubModuleInstance((I_Module*)myWinMod);
    myWinMod = NULL;

    if (myConsts)
        destroySubModuleInstance((I_Module*)myConsts);
    myConsts = NULL;

    if (myGroupMod)
        destroySubModuleInstance((I_Module*)myGroupMod);
    myGroupMod = NULL;

    if (myArgMod)
        destroySubModuleInstance((I_Module*)myArgMod);
    myArgMod = NULL;
}

/**
 * Warns if the memory of a window that is being created overlaps with the
 * memory of any window that is already tracked.
 *
 * @param pId parallel id of the call.
 * @param lId location id of the call.
 * @param base base address passed to the window creation call.
 * @param size size passed to the window creation call.
 * @param kind window kind (a MUST_WIN_KIND value); for MUST_WIN_DYNAMIC the
 *        address from myConsts->getBottom() is used instead of base.
 * @return GTI_ANALYSIS_FAILURE after reporting the first overlap found,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::warningIfWindowsOverlapCreate(
    MustParallelId pId,
    MustLocationId lId,
    void* base,
    int size,
    int kind)
{
    MustAddressType start = ((MUST_WIN_KIND)kind == MUST_WIN_DYNAMIC)
                                ? (MustAddressType)myConsts->getBottom()
                                : (MustAddressType)base;

    // Memory interval describing the new window (built from start and size).
    auto interval = memInterval(
        StridedBlock(start, start, true, 0, 1, (MustAddressType)size, 0),
        0,
        (MustRequestType)0,
        true,
        NULL,
        start,
        0);

    for (auto& handle : myWinMod->getUserHandles()) {
        I_Win* other = myWinMod->getWin(pId, handle.second);
        // Do not dereference a handle the tracker cannot resolve.
        if (other == NULL)
            continue;

        if (isOverlapped(other->getMemIntervals(), interval)) {
            std::stringstream stream;
            stream << "Overlapping windows.";
            myLogger->createMessage(
                MUST_WARNING_WIN_OVERLAP,
                pId,
                lId,
                MustWarningMessage,
                stream.str());
            return GTI_ANALYSIS_FAILURE;
        }
    }
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Warns if any memory interval of window win overlaps with the memory of
 * another tracked window.
 *
 * @param pId parallel id of the call.
 * @param lId location id of the call.
 * @param win window whose memory intervals are compared against all others.
 * @return GTI_ANALYSIS_FAILURE after reporting the first overlap found,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::warningIfWindowsOverlap(MustParallelId pId, MustLocationId lId, MustWinType win)
{
    auto& memIntervals = myWinMod->getWin(pId, win)->getMemIntervals();

    // Handles of all tracked windows
    auto handles = myWinMod->getUserHandles();
    for (auto& handle : handles) {
        // Skip current window
        if (handle.second == win)
            continue;

        I_Win* other = myWinMod->getWin(pId, handle.second);
        // Do not dereference a handle the tracker cannot resolve.
        if (other == NULL)
            continue;

        // Loop-invariant for the inner loop: look the intervals up once.
        auto& otherIntervals = other->getMemIntervals();
        for (auto& mem : memIntervals) {
            if (isOverlapped(otherIntervals, mem)) {
                std::stringstream stream;
                stream << "Overlapping windows.";
                myLogger->createMessage(
                    MUST_WARNING_WIN_OVERLAP,
                    pId,
                    lId,
                    MustWarningMessage,
                    stream.str());
                return GTI_ANALYSIS_FAILURE;
            }
        }
    }
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Reports an error if the memory intervals recorded for window win overlap
 * each other (the message refers to overlapping MPI_Win_attach calls).
 *
 * @param pId parallel id of the call.
 * @param lId location id of the call.
 * @param win window whose memory intervals are checked.
 * @return GTI_ANALYSIS_FAILURE if an overlap was reported,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfWinAttachOverlap(MustParallelId pId, MustLocationId lId, MustWinType win)
{
    auto& memIntervals = myWinMod->getWin(pId, win)->getMemIntervals();

    // Output arguments required by isOverlapped(); presumably they describe the
    // overlap that was found, but they are not used for the message.
    MustMemIntervalListType::iterator iter, nextIter;
    MustAddressType posA, posB;
    if (!isOverlapped(memIntervals, iter, nextIter, posA, posB, false))
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream stream;
    stream << "Overlapping MPI_Win_attach.";
    myLogger->createMessage(
        MUST_ERROR_WIN_ATTACH_OVERLAP,
        pId,
        lId,
        MustErrorMessage,
        stream.str());
    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfConflictingActiveEpoch
//=============================
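// Must be scheduled before WinTrack::startEpoch, so that the check sees the epoch state
// before the new synchronization call is recorded.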
GTI_ANALYSIS_RETURN WinChecks::errorIfConflictingActiveEpoch(
    MustParallelId pId,
    MustLocationId lId,
    int syncKind,
    int target_rank,
    MustWinType win)
{
    // Reports an error if the new synchronization call (syncKind, a
    // MUST_WIN_EPOCH_SYNC value) conflicts with the epoch active on win:
    //  - mixing different synchronization mechanisms for access epochs or for
    //    exposure epochs, or
    //  - repeating MPI_Win_lock on the same target_rank, MPI_Win_lock_all,
    //    MPI_Win_post or MPI_Win_start while that epoch is still open.
    // Returns GTI_ANALYSIS_FAILURE whenever syncKind is not part of the active
    // epoch (see NOTE below) or a repetition was reported.
    I_WinEpoch* epoch = myWinMod->getWin(pId, win)->getEpoch();
    const auto activeSync = epoch->getSync();
    const auto newSync = (MUST_WIN_EPOCH_SYNC)syncKind;

    if ((activeSync & ~MUST_WIN_EPOCH_POTENTIAL_FENCE) == MUST_WIN_EPOCH_NONE) {
        // There currently is no open epoch, hence we do not have any conflict
        return GTI_ANALYSIS_SUCCESS;
    }

    // A (potential) fence following a fence epoch is not a conflict.
    if ((newSync & MUST_WIN_EPOCH_POTENTIAL_FENCE) && (activeSync & MUST_WIN_EPOCH_FENCE))
        return GTI_ANALYSIS_SUCCESS;

    // Collects the locations of all epochs recorded on the window as references.
    auto collectEpochRefs = [epoch, pId]() {
        std::list<std::pair<MustParallelId, MustLocationId>> refs;
        for (const auto& entry : epoch->getEpochLIdMap())
            refs.push_back(std::make_pair(pId, entry.second));
        return refs;
    };

    if ((activeSync & newSync) == 0) {
        // The new call uses another synchronization kind than the active epoch;
        // distinguish between access and exposure epochs.
        if ((activeSync & MUST_WIN_ACCESS_EPOCH_MASK) && (newSync & MUST_WIN_ACCESS_EPOCH_MASK)) {
            std::stringstream stream;
            stream << "Mixed two different window synchronization mechanisms.";
            std::list<std::pair<MustParallelId, MustLocationId>> refs = collectEpochRefs();
            myLogger->createMessage(
                MUST_ERROR_WIN_EPOCH,
                pId,
                lId,
                MustErrorMessage,
                stream.str(),
                refs);
        }
        if ((activeSync & MUST_WIN_EXPOSURE_EPOCH_MASK) &&
            (newSync & MUST_WIN_EXPOSURE_EPOCH_MASK)) {
            std::stringstream stream;
            stream << "Mixed two different window synchronization mechanisms for exposure epochs.";
            std::list<std::pair<MustParallelId, MustLocationId>> refs = collectEpochRefs();
            myLogger->createMessage(
                MUST_ERROR_WIN_EPOCH,
                pId,
                lId,
                MustErrorMessage,
                stream.str(),
                refs);
        }
        // NOTE(review): if neither mask test above matches, FAILURE is returned
        // without any message being emitted — confirm this is intended.
        return GTI_ANALYSIS_FAILURE;
    }

    // The synchronization kind is already part of the active epoch.
    std::stringstream stream;
    switch (newSync) {
    case MUST_WIN_EPOCH_LOCK:
        // Locking different ranks is allowed
        if (epoch->getEpochLId(target_rank) == MUST_INVALID_LOCATION_ID)
            return GTI_ANALYSIS_SUCCESS;
        stream << "Rank " << myPIdMod->getInfoForId(pId).rank << " locked the window of rank "
               << target_rank << " a second time before unlocking it first.";
        break;
    case MUST_WIN_EPOCH_LOCK_ALL:
        stream << "Rank " << myPIdMod->getInfoForId(pId).rank
               << " called MPI_Win_lock_all a second time before unlocking the window first.";
        break;
    case MUST_WIN_EPOCH_POST:
        stream << "Rank " << myPIdMod->getInfoForId(pId).rank
               << " called MPI_Win_post a second time before calling MPI_Win_complete.";
        break;
    case MUST_WIN_EPOCH_START:
        stream << "Rank " << myPIdMod->getInfoForId(pId).rank
               << " called MPI_Win_start a second time before calling MPI_Win_complete.";
        break;
    case MUST_WIN_EPOCH_FENCE:
    case MUST_WIN_EPOCH_POTENTIAL_FENCE:
    case MUST_WIN_EPOCH_NONE:
    default:
        // Repetitions of these kinds are not reported by this check.
        return GTI_ANALYSIS_SUCCESS;
    }

    // Reference the location recorded for the already open epoch of this kind.
    int key = epoch->getEpochLIdKey(newSync, target_rank);
    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    refs.push_back(std::make_pair(pId, epoch->getEpochLId(key)));
    myLogger->createMessage(MUST_ERROR_WIN_LOCK, pId, lId, MustErrorMessage, stream.str(), refs);

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfRMACallOutsideEpoch
//=============================
GTI_ANALYSIS_RETURN WinChecks::errorIfRMACallOutsideEpoch(
    MustParallelId pId,
    MustLocationId lId,
    int target_rank,
    MustWinType win)
{
    // Reports an error if an RMA communication call targeting target_rank is
    // issued outside of an active access epoch of win. Reported cases:
    //  - no epoch is active at all,
    //  - a lock epoch is active but no lock location is recorded for target_rank,
    //  - an MPI_Win_start epoch is active but target_rank is not among the
    //    ranks recorded by getWinStartRanks().
    // Always returns GTI_ANALYSIS_SUCCESS, also after reporting an error.

    // A communication with MPI_PROC_NULL has no effect.
    if (myConsts->isProcNull(target_rank))
        return GTI_ANALYSIS_SUCCESS;

    I_WinEpoch* epoch = myWinMod->getWin(pId, win)->getEpoch();
    const auto sync = epoch->getSync();

    if (sync == MUST_WIN_EPOCH_NONE) {
        std::stringstream stream;
        stream << "RMA communication call outside of an active access epoch.";
        myLogger->createMessage(MUST_ERROR_WIN_EPOCH, pId, lId, MustErrorMessage, stream.str());
        return GTI_ANALYSIS_SUCCESS;
    }

    if ((sync & MUST_WIN_EPOCH_LOCK) &&
        epoch->getEpochLId(target_rank) == MUST_INVALID_LOCATION_ID) {
        std::stringstream stream;
        std::list<std::pair<MustParallelId, MustLocationId>> refs;
        stream << "RMA communication call outside of an active MPI_Win_lock access epoch. ";
        stream << "The references list the start locations of the detected active MPI_Win_lock "
                  "access epochs";
        // Bind the map once: taking lower_bound() and end() from two separate
        // getEpochLIdMap() calls is undefined if the getter returns a copy.
        const auto& epochLIds = epoch->getEpochLIdMap();
        // Only non-negative keys are listed; presumably these are the
        // per-target lock epochs, while other sync kinds use negative keys.
        for (auto it = epochLIds.lower_bound(0); it != epochLIds.end(); ++it)
            refs.push_back(std::make_pair(pId, it->second));
        myLogger
            ->createMessage(MUST_ERROR_WIN_EPOCH, pId, lId, MustErrorMessage, stream.str(), refs);
    } else if (
        (sync & MUST_WIN_EPOCH_START) && epoch->getWinStartRanks().count(target_rank) == 0) {
        std::stringstream stream;
        std::list<std::pair<MustParallelId, MustLocationId>> refs;
        stream << "RMA communication call outside of an active MPI_Win_start access epoch. ";
        stream << "Reference 1 is the start location of the detected active MPI_Win_start "
                  "access epoch ";
        int key = epoch->getEpochLIdKey(MUST_WIN_EPOCH_START, -MUST_WIN_EPOCH_START);
        refs.push_back(std::make_pair(pId, epoch->getEpochLId(key)));
        myLogger
            ->createMessage(MUST_ERROR_WIN_EPOCH, pId, lId, MustErrorMessage, stream.str(), refs);
    }
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// errorIfFenceAssertNoPrecede
//=============================
/**
 * Reports a violated MPI_MODE_NOPRECEDE assertion of MPI_Win_fence.
 *
 * The epoch being closed carries MUST_WIN_EPOCH_FENCE (rather than
 * MUST_WIN_EPOCH_POTENTIAL_FENCE) only if an RMA call happened in it. Closing
 * such an epoch with a fence that asserts MPI_MODE_NOPRECEDE is an error.
 *
 * @param assert assertion argument of the new MPI_Win_fence call.
 * @return GTI_ANALYSIS_FAILURE if the violation was reported,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfFenceAssertNoPrecede(
    MustParallelId pId,
    MustLocationId lId,
    int assert,
    MustWinType win)
{
    I_WinEpoch* epoch = myWinMod->getWin(pId, win)->getEpoch();

    const bool closingEpochHadRma = epoch->getSync() == MUST_WIN_EPOCH_FENCE;
    if (!closingEpochHadRma || !(assert & myConsts->getModeNoprecede()))
        return GTI_ANALYSIS_SUCCESS;

    // Reference 1: last RMA call of the closing epoch;
    // reference 2: location recorded for the fence epoch.
    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    refs.push_back(std::make_pair(pId, epoch->getLastRMALId()));
    const int fenceKey = epoch->getEpochLIdKey(MUST_WIN_EPOCH_FENCE, -MUST_WIN_EPOCH_FENCE);
    refs.push_back(std::make_pair(pId, epoch->getEpochLId(fenceKey)));

    std::stringstream stream;
    stream << "MPI_MODE_NOPRECEDE assertion of MPI_Win_fence violated.";
    myLogger->createMessage(
        MUST_ERROR_WIN_FENCE_ASSERT,
        pId,
        lId,
        MustErrorMessage,
        stream.str(),
        refs);
    return GTI_ANALYSIS_FAILURE;
}

/**
 * Reports an error if the window given as argument aId is not known to the
 * window tracker.
 *
 * @return GTI_ANALYSIS_FAILURE if the window is unknown, GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfWinNotKnown(MustParallelId pId, MustLocationId lId, int aId, MustWinType win)
{
    const bool known = myWinMod->getWin(pId, win) != NULL;
    if (known)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
        << ") is an unknown window where a valid window was expected.";
    myLogger->createMessage(MUST_ERROR_WIN_UNKNOWN, pId, lId, MustErrorMessage, msg.str());
    return GTI_ANALYSIS_FAILURE;
}

/**
 * Reports an error if the window given as argument aId is MPI_WIN_NULL.
 *
 * Unknown windows pass this check.
 *
 * @return GTI_ANALYSIS_FAILURE if the window is MPI_WIN_NULL, GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfWinNull(MustParallelId pId, MustLocationId lId, int aId, MustWinType win)
{
    I_Win* winInfo = myWinMod->getWin(pId, win);
    if (winInfo == NULL || !winInfo->isNull())
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
        << ") is MPI_WIN_NULL where a valid window was expected.";
    myLogger->createMessage(MUST_ERROR_WIN_NULL, pId, lId, MustErrorMessage, msg.str());
    return GTI_ANALYSIS_FAILURE;
}

/**
 * Reports an error if the rank given as argument aId is not lower than the
 * size of the communicator of window win.
 *
 * @param value rank to check.
 * @return GTI_ANALYSIS_FAILURE if the window is unknown or MPI_WIN_NULL (no
 *         message is emitted in that case) or if value is too large;
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfRankGreaterEqualWinSize(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    int value,
    MustWinType win)
{
    I_Win* info = myWinMod->getWin(pId, win);
    // Without a valid window there is no communicator to compare against.
    if (info == NULL || info->isNull())
        return GTI_ANALYSIS_FAILURE;

    const int commSize = info->getComm()->getGroup()->getSize();
    if (value < commSize)
        return GTI_ANALYSIS_SUCCESS;

    // Fragments are joined with exactly one space (the original message had a
    // double space after "equal to"/"greater than" and no space before the
    // window information).
    std::stringstream stream;
    stream << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
           << ") specifies a rank that is " << (value == commSize ? "equal to" : "greater than")
           << " the size of the given window's communicator, while the value must be lower "
              "than the size of the window's communicator. "
           << "(" << myArgMod->getArgName(aId) << "=" << value
           << ", window's communicator size:" << commSize << ")! ";

    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    stream << "(Information on window: ";
    info->printInfo(stream, &refs);
    stream << ")";

    myLogger->createMessage(
        MUST_ERROR_INTEGER_GREATER_EQUAL_COMM_SIZE,
        pId,
        lId,
        MustErrorMessage,
        stream.str(),
        refs);

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfFenceAssertNoSucceed
//=============================
/**
 * Reports a violated MPI_MODE_NOSUCCEED assertion of MPI_Win_fence.
 *
 * The violation is an RMA call in the current epoch while all of the following hold:
 *  - the epoch has the MUST_WIN_EPOCH_FENCE sync type (it only gets it after an
 *    RMA call; otherwise it is MUST_WIN_EPOCH_POTENTIAL_FENCE),
 *  - its recorded assertion contains MPI_MODE_NOSUCCEED,
 *  - a last RMA location is recorded.
 * RMA calls targeting MPI_PROC_NULL are ignored.
 *
 * @return GTI_ANALYSIS_FAILURE if the violation was reported,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfFenceAssertNoSucceed(
    MustParallelId pId,
    MustLocationId lId,
    int target_rank,
    MustWinType win)
{
    if (myConsts->isProcNull(target_rank))
        return GTI_ANALYSIS_SUCCESS;

    I_WinEpoch* epoch = myWinMod->getWin(pId, win)->getEpoch();

    const bool violated = epoch->getSync() == MUST_WIN_EPOCH_FENCE &&
                          (epoch->getAssertion() & myConsts->getModeNosucceed()) &&
                          epoch->getLastRMALId() != MUST_INVALID_LOCATION_ID;
    if (!violated)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "MPI_MODE_NOSUCCEED assertion of MPI_Win_fence violated.";

    // Reference 1: last RMA call; reference 2: location recorded for the fence epoch.
    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    refs.push_back(std::make_pair(pId, epoch->getLastRMALId()));
    const int fenceKey = epoch->getEpochLIdKey(MUST_WIN_EPOCH_FENCE, target_rank);
    refs.push_back(std::make_pair(pId, epoch->getEpochLId(fenceKey)));

    myLogger->createMessage(
        MUST_ERROR_WIN_FENCE_ASSERT,
        pId,
        lId,
        MustErrorMessage,
        msg.str(),
        refs);
    return GTI_ANALYSIS_FAILURE;
}

/**
 * Reports an error if the group given as argument aId contains ranks that are
 * not part of the communicator of window win.
 *
 * At most MUST_MAX_NUM_RESOURCES offending ranks are listed in the message.
 *
 * @return GTI_ANALYSIS_FAILURE if the group or the window is unknown or null
 *         (no message is emitted then), or if an offending rank was reported;
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfRankInGroupNotInWin(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    MustGroupType group,
    MustWinType win)
{
    I_Win* winInfo = myWinMod->getWin(pId, win);
    I_Group* groupInfo = myGroupMod->getGroup(pId, group);
    if (groupInfo == NULL || groupInfo->isNull()) {
        return GTI_ANALYSIS_FAILURE;
    }
    // The window's communicator is dereferenced below, so an unknown window or
    // MPI_WIN_NULL cannot be checked here (previously this crashed).
    if (winInfo == NULL || winInfo->isNull()) {
        return GTI_ANALYSIS_FAILURE;
    }

    // Loop-invariant lookups.
    auto groupTable = groupInfo->getGroup();
    auto winGroupTable = winInfo->getComm()->getGroup();

    int winRank = 0;
    int worldRank = 0;
    std::stringstream rankStream;
    int numReported = 0;
    for (int groupRank = 0; groupRank < groupTable->getSize(); groupRank++) {
        groupTable->translate(groupRank, &worldRank);
        // containsWorldRank() returning false means worldRank is not in the
        // window's communicator group.
        if (!winGroupTable->containsWorldRank(worldRank, &winRank)) {
            rankStream << "group rank: " << groupRank << " (world rank: " << worldRank << "), ";
            if (++numReported >= MUST_MAX_NUM_RESOURCES)
                break;
        }
    }

    if (numReported > 0) {
        std::stringstream stream;
        stream << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
               << ") specifies at least one rank that is not in the window's communicator: ";
        stream << rankStream.str();
        std::list<std::pair<MustParallelId, MustLocationId>> refs;
        stream << "(Information on group: ";
        groupInfo->printInfo(stream, &refs);
        stream << ")";
        stream << "(Information on window: ";
        winInfo->printInfo(stream, &refs);
        stream << ")";

        myLogger->createMessage(
            MUST_ERROR_GROUP_NOT_PART_OF_WIN,
            pId,
            lId,
            MustErrorMessage,
            stream.str(),
            refs);

        return GTI_ANALYSIS_FAILURE;
    }
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// errorIfOpenEpochOnWinFree
//=============================
/**
 * Reports an error if window win is freed while an epoch is still open.
 *
 * Having no epoch, or only a potential fence epoch (a fence with no RMA call
 * after it), is accepted.
 *
 * @return GTI_ANALYSIS_FAILURE if an open epoch was reported,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfOpenEpochOnWinFree(MustParallelId pId, MustLocationId lId, MustWinType win)
{
    I_WinEpoch* epoch = myWinMod->getWin(pId, win)->getEpoch();

    const auto sync = epoch->getSync();
    if (sync == MUST_WIN_EPOCH_NONE || sync == MUST_WIN_EPOCH_POTENTIAL_FENCE)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Rank " << myPIdMod->getInfoForId(pId).rank
        << " called MPI_Win_free while still having an open access epoch.";

    // Reference every epoch location recorded on the window.
    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    for (const auto& entry : epoch->getEpochLIdMap())
        refs.push_back(std::make_pair(pId, entry.second));

    myLogger->createMessage(MUST_ERROR_WIN_EPOCH, pId, lId, MustErrorMessage, msg.str(), refs);
    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfUnlockingWrongRank
//=============================
/**
 * Reports an error if MPI_Win_unlock targets a rank for which no lock epoch
 * location is recorded on window win.
 *
 * @return GTI_ANALYSIS_FAILURE if the error was reported, GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
WinChecks::errorIfUnlockingWrongRank(
    MustParallelId pId,
    MustLocationId lId,
    int target_rank,
    MustWinType win)
{
    I_WinEpoch* epoch = myWinMod->getWin(pId, win)->getEpoch();

    // A recorded location for target_rank is taken to mean it was locked.
    const bool wasLocked = epoch->getEpochLId(target_rank) != MUST_INVALID_LOCATION_ID;
    if (wasLocked)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "MPI_Win_unlock attempted to unlock the window of rank " << target_rank
        << " but it was never locked.";
    myLogger->createMessage(MUST_ERROR_WIN_LOCK, pId, lId, MustErrorMessage, msg.str());
    return GTI_ANALYSIS_FAILURE;
}
