/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file RMAValueChecks.cpp
 *       @see MUST::RMAValueChecks.
 *
 *  @date 15.01.2025
 *  @author Cornelius Pätzold
 */

#include "GtiMacros.h"
#include "RMAValueChecks.h"
#include "MustEnums.h"
#include "PrefixedOstream.hpp"

#include <sstream>

using namespace must;

// GTI module boilerplate (macros from GtiMacros.h). Judging by their names these
// expand to the module's get-instance, free-instance and PnMPI registration-point
// functions — NOTE(review): confirm against GtiMacros.h.
mGET_INSTANCE_FUNCTION(RMAValueChecks)
mFREE_INSTANCE_FUNCTION(RMAValueChecks)
mPNMPI_REGISTRATIONPOINT_FUNCTION(RMAValueChecks)

//=============================
// Constructor
//=============================
/**
 * Creates the module and binds its sub modules.
 *
 * Expected sub module order (from the analysis specification):
 *   [0] I_ParallelIdAnalysis, [1] I_CreateMessage,
 *   [2] I_ArgumentAnalysis,   [3] I_BaseConstants.
 *
 * Surplus sub modules are released immediately. If too few sub modules are
 * provided, an error is printed and assert(0) fires. In NDEBUG builds, where
 * assert is a no-op, all received sub modules are released and every sub
 * module pointer is left NULL, so that init() returns GTI_ANALYSIS_FAILURE.
 * Without this the constructor would read past the end of the vector.
 *
 * @param instanceName name of this module instance, forwarded to ModuleBase.
 */
RMAValueChecks::RMAValueChecks(const char* instanceName)
    : gti::ModuleBase<RMAValueChecks, I_RMAValueChecks>(instanceName)
{
    // Start from a defined state; the destructor and init() test these for NULL.
    myPIdMod = NULL;
    myLogger = NULL;
    myArgMod = NULL;
    myConstMod = NULL;

    // create sub modules
    std::vector<I_Module*> subModInstances = createSubModuleInstances();

    // Number of sub modules this module consumes (see analysis specification).
    const std::vector<I_Module*>::size_type numModules = 4;

    if (subModInstances.size() < numModules) {
        must::cerr << "Module has not enough sub modules, check its analysis specification! ("
                   << __FILE__ << "@" << __LINE__ << ")" << std::endl;
        assert(0);
        // assert() is compiled out with NDEBUG: do not index past the end of
        // subModInstances; release what we got and leave all pointers NULL.
        for (std::vector<I_Module*>::size_type i = 0; i < subModInstances.size(); i++)
            destroySubModuleInstance(subModInstances[i]);
        return;
    }

    // Release any surplus sub modules this module does not use.
    for (std::vector<I_Module*>::size_type i = numModules; i < subModInstances.size(); i++)
        destroySubModuleInstance(subModInstances[i]);

    myPIdMod = (I_ParallelIdAnalysis*)subModInstances[0];
    myLogger = (I_CreateMessage*)subModInstances[1];
    myArgMod = (I_ArgumentAnalysis*)subModInstances[2];
    myConstMod = (I_BaseConstants*)subModInstances[3];
}

//=============================
// Destructor
//=============================
/**
 * Hands every sub module acquired in the constructor back to GTI, in
 * acquisition order, skipping entries that are NULL. Afterwards all
 * sub module pointers are cleared.
 */
RMAValueChecks::~RMAValueChecks()
{
    I_Module* const owned[] = {
        (I_Module*)myPIdMod,
        (I_Module*)myLogger,
        (I_Module*)myArgMod,
        (I_Module*)myConstMod};
    const int ownedCount = (int)(sizeof(owned) / sizeof(owned[0]));

    for (int idx = 0; idx < ownedCount; ++idx) {
        if (owned[idx] != NULL)
            destroySubModuleInstance(owned[idx]);
    }

    myPIdMod = NULL;
    myLogger = NULL;
    myArgMod = NULL;
    myConstMod = NULL;
}

//=============================
// init
//=============================
/**
 * Caches, per synchronization call kind, the bit mask of assertion flags that
 * are accepted. It also caches the union of the two lock type constants. All
 * values are queried from the I_BaseConstants sub module.
 *
 * @return GTI_ANALYSIS_FAILURE if the constants sub module is missing,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN RMAValueChecks::init()
{
    I_BaseConstants* const constants = myConstMod;
    if (constants == NULL)
        return GTI_ANALYSIS_FAILURE;

    const int nocheck = constants->getModeNocheck();
    const int nostore = constants->getModeNostore();
    const int noput = constants->getModeNoput();
    const int noprecede = constants->getModeNoprecede();
    const int nosucceed = constants->getModeNosucceed();

    myValidFenceAssertions = nostore | noput | noprecede | nosucceed;
    myValidPostAssertions = nocheck | nostore | noput;
    myValidStartAssertions = nocheck;
    myValidLockAssertions = nocheck;
    myValidLockTypes = constants->getLockExclusive() | constants->getLockShared();

    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Shared helper for the assertion checks below.
 *
 * Emits MUST_ERROR_INVALID_ASSERTION via the logger when @p assertion has any
 * bit set that is not contained in @p validAssertions.
 *
 * @param pId             parallel id of the checked call.
 * @param lId             location id of the checked call.
 * @param assertion       assertion argument passed by the application.
 * @param validAssertions bit mask of flags allowed for this call.
 * @return Always GTI_ANALYSIS_SUCCESS, including after reporting an error.
 */
GTI_ANALYSIS_RETURN RMAValueChecks::errorIfInvalidAssertion(
    MustParallelId pId,
    MustLocationId lId,
    int assertion,
    int validAssertions)
{
    const int unexpectedBits = assertion & ~validAssertions;
    if (unexpectedBits == 0)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "The provided assertion (" << assertion << ") is invalid.";
    myLogger->createMessage(MUST_ERROR_INVALID_ASSERTION, pId, lId, MustErrorMessage, msg.str());
    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// errorIfInvalidFenceAssertion
//=============================
/**
 * Validates a fence assertion (presumably the MPI_Win_fence argument) against
 * the mask cached by init(); see errorIfInvalidAssertion for the reporting.
 */
GTI_ANALYSIS_RETURN RMAValueChecks::errorIfInvalidFenceAssertion(
    MustParallelId pId,
    MustLocationId lId,
    int assertion)
{
    return errorIfInvalidAssertion(pId, lId, assertion, myValidFenceAssertions);
}

//=============================
// errorIfInvalidPostAssertion
//=============================
/**
 * Validates a post assertion (presumably the MPI_Win_post argument) against
 * the mask cached by init(); see errorIfInvalidAssertion for the reporting.
 */
GTI_ANALYSIS_RETURN RMAValueChecks::errorIfInvalidPostAssertion(
    MustParallelId pId,
    MustLocationId lId,
    int assertion)
{
    return errorIfInvalidAssertion(pId, lId, assertion, myValidPostAssertions);
}

//=============================
// errorIfInvalidStartAssertion
//=============================
/**
 * Validates a start assertion (presumably the MPI_Win_start argument) against
 * the mask cached by init(); see errorIfInvalidAssertion for the reporting.
 */
GTI_ANALYSIS_RETURN RMAValueChecks::errorIfInvalidStartAssertion(
    MustParallelId pId,
    MustLocationId lId,
    int assertion)
{
    return errorIfInvalidAssertion(pId, lId, assertion, myValidStartAssertions);
}

//=============================
// errorIfInvalidLockAssertion
//=============================
/**
 * Validates a lock assertion (presumably from MPI_Win_lock / MPI_Win_lock_all)
 * against the mask cached by init(); see errorIfInvalidAssertion for the
 * reporting.
 */
GTI_ANALYSIS_RETURN RMAValueChecks::errorIfInvalidLockAssertion(
    MustParallelId pId,
    MustLocationId lId,
    int assertion)
{
    return errorIfInvalidAssertion(pId, lId, assertion, myValidLockAssertions);
}

//=============================
// errorIfInvalidLockType
//=============================
/**
 * Reports an error if @p lock_type is neither the exclusive nor the shared
 * lock type constant provided by the I_BaseConstants sub module.
 *
 * Lock types are distinct enumerated values, not bit flags, so they must be
 * compared for equality. A mask test against (EXCLUSIVE | SHARED) would accept
 * any value whose bits are a subset of that union. For example, 0 is always
 * accepted, and 3 is accepted when the constants are 1 and 2.
 *
 * @param pId       parallel id of the MPI_Win_lock call.
 * @param lId       location id of the MPI_Win_lock call.
 * @param lock_type lock type argument passed by the application.
 * @return GTI_ANALYSIS_FAILURE after reporting an invalid lock type,
 *         GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN
RMAValueChecks::errorIfInvalidLockType(MustParallelId pId, MustLocationId lId, int lock_type)
{
    if (lock_type == myConstMod->getLockExclusive() || lock_type == myConstMod->getLockShared())
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream stream;
    stream << "The provided lock type (" << lock_type << ") for MPI_Win_lock is invalid.";
    myLogger->createMessage(MUST_ERROR_INVALID_ASSERTION, pId, lId, MustErrorMessage, stream.str());
    return GTI_ANALYSIS_FAILURE;
}

/*EOF*/
