/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file TargetRaceVerifier.cpp
 *       @see must::TargetRaceVerifier.
 *
 *  @date 02.05.2025
 *  @author Simon Schwitanski
 */

#include "GtiMacros.h"
#include "TargetRaceVerifier.h"
#include "MustEnums.h"
#include "MustDefines.h"
#include "pnmpi/service.h"
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include "PrefixedOstream.hpp"

#include <sstream>
#include <fstream>
#include <utility>

using namespace must;

// Macro-generated module boilerplate for TargetRaceVerifier (macros presumably
// provided via GtiMacros.h): as their names suggest, the get/free instance
// functions and the PnMPI registration point of this module.
mGET_INSTANCE_FUNCTION(TargetRaceVerifier)
mFREE_INSTANCE_FUNCTION(TargetRaceVerifier)
mPNMPI_REGISTRATIONPOINT_FUNCTION(TargetRaceVerifier)

//=============================
// Constructor.
//=============================
/**
 * Creates the module and binds its sub modules.
 *
 * Exactly two sub modules are used, in this order (matching the casts below):
 *   [0] parallel id analysis (I_ParallelIdAnalysis)
 *   [1] datatype tracking    (I_DatatypeTrack)
 * Surplus sub modules are destroyed right away. If too few are provided, an error
 * is printed and the assertion fires; in NDEBUG builds, where the assertion is a
 * no-op, the module leaves both sub module pointers null instead of indexing past
 * the end of the instance list.
 *
 * @param instanceName name of this module instance, forwarded to ModuleBase.
 */
TargetRaceVerifier::TargetRaceVerifier(const char* instanceName)
    : ModuleBase<TargetRaceVerifier, I_TargetRaceVerifier>(instanceName)
{
    constexpr std::vector<I_Module*>::size_type numSubModules = 2;

    // Start from a well-defined state: the destructor tests these pointers.
    myPIdMod = nullptr;
    myDatMod = nullptr;

    // create sub modules
    std::vector<I_Module*> subModInstances = createSubModuleInstances();

    if (subModInstances.size() < numSubModules) {
        std::cerr << "Module has not enough sub modules, check its analysis specification! ("
                  << __FILE__ << "@" << __LINE__ << ")" << std::endl;
        assert(0);
        // assert() compiles away under NDEBUG: release what was created and bail out
        // rather than reading subModInstances[0]/[1] out of bounds.
        for (I_Module* subMod : subModInstances)
            destroySubModuleInstance(subMod);
        return;
    }

    // Release sub modules beyond the ones this module uses.
    for (auto i = numSubModules; i < subModInstances.size(); ++i)
        destroySubModuleInstance(subModInstances[i]);

    myPIdMod = static_cast<I_ParallelIdAnalysis*>(subModInstances[0]);
    myDatMod = static_cast<I_DatatypeTrack*>(subModInstances[1]);
}

//=============================
// isRace
//=============================
/**
 * Decides whether a pair of potentially conflicting RMA operations is a race.
 *
 * The pair is considered safe (returns false) if
 *  1) both operations are atomic, have the same datatype, and their start addresses
 *     differ by a multiple of the type size (so the typed elements do not partially
 *     overlap), or
 *  2) all start/end clocks are recorded and the two operations' [start, end] clock
 *     windows do not overlap (one ends before the other starts).
 * Otherwise, including when the access history of either operation cannot be
 * found, the pair is reported as a race (returns true), since it cannot be
 * proven safe.
 *
 * @param pId1   parallel id of the process that issued the first operation.
 * @param rmaId1 RMA id of the first operation.
 * @param pId2   parallel id of the process that issued the second operation.
 * @param rmaId2 RMA id of the second operation.
 * @return true if the pair must be treated as a race, false if it is safe.
 */
bool TargetRaceVerifier::isRace(
    MustParallelId pId1,
    MustRMAId rmaId1,
    MustParallelId pId2,
    MustRMAId rmaId2)
{
    auto rank1 = myPIdMod->getInfoForId(pId1).rank;
    auto rank2 = myPIdMod->getInfoForId(pId2).rank;

    RMAOpHistoryData op1;
    RMAOpHistoryData op2;
    const bool found1 = myRMAAnnotationHistory[rank1].find(rmaId1, op1);
    if (!found1) {
        std::cerr << "Error: Could not find RMA access history for operation " << rmaId1
                  << " from rank " << rank1 << std::endl;
    }
    const bool found2 = myRMAAnnotationHistory[rank2].find(rmaId2, op2);
    if (!found2) {
        std::cerr << "Error: Could not find RMA access history for operation " << rmaId2
                  << " from rank " << rank2 << std::endl;
    }

    // Without both history entries op1/op2 hold no recorded data, so nothing can be
    // shown to be safe: keep the race report.
    if (!found1 || !found2)
        return true;

    // In order to be a safe pair of atomic accesses, the operations should be
    // 1) both atomic,
    // 2) of the same datatype, and
    // 3) their start addresses should differ by a multiple of the type size (otherwise
    //    the typed elements would partially overlap).
    // The absolute distance avoids unsigned wrap-around of the subtraction, and a zero
    // type size is excluded to avoid a division by zero.
    if (op1.isAtomic && op2.isAtomic && op1.type == op2.type && op1.typeSize != 0) {
        const auto addrDistance = (op1.startAddr >= op2.startAddr)
                                      ? op1.startAddr - op2.startAddr
                                      : op2.startAddr - op1.startAddr;
        if (addrDistance % op1.typeSize == 0)
            return false; // atomic access, same atomic datatypes, aligned
    }

    // If the access windows of the operations are disjoint, the pair is safe.
    // All four clocks take part in the comparison, so all of them must be recorded.
    const bool haveClocks = op1.startClock.size() > 0 && op1.endClock.size() > 0 &&
                            op2.startClock.size() > 0 && op2.endClock.size() > 0;
    if (haveClocks && (op1.endClock < op2.startClock || op2.endClock < op1.startClock)) {
        return false;
    }

    return true;
}

/**
 * Records the access-history entry of an RMA operation issued by rank @p origin.
 *
 * The entry is stored in the per-rank history under @p rmaId, where isRace() and
 * getHistoryData() look it up later.
 *
 * @param origin     rank that issued the operation (index into the history).
 * @param rmaId      RMA id of the operation.
 * @param isAtomic   whether the operation is atomic.
 * @param epoch      epoch the operation belongs to.
 * @param type       datatype id of the accessed elements.
 * @param typeSize   size of that datatype; used by isRace() for the alignment check.
 * @param startAddr  start address of the access.
 * @param lIdStart   location id associated with the start of the operation.
 * @param lIdEnd     location id associated with the end of the operation.
 * @param startClock clock value at the start of the operation (moved into the entry).
 * @param endClock   clock value at the end of the operation (moved into the entry).
 */
void TargetRaceVerifier::addHistoryData(
    int origin,
    MustRMAId rmaId,
    bool isAtomic,
    int epoch,
    RMADataTypeId type,
    size_t typeSize,
    MustAddressType startAddr,
    MustLocationId lIdStart,
    MustLocationId lIdEnd,
    Clock startClock,
    Clock endClock)
{
    RMAOpHistoryData op;
    op.callId = rmaId;
    op.isAtomic = isAtomic;
    op.epoch = epoch;
    op.type = type;
    op.typeSize = typeSize;
    op.startAddr = startAddr;
    op.lIdStart = lIdStart;
    op.lIdEnd = lIdEnd;
    // The clocks are by-value sink parameters: move them instead of copying again.
    op.startClock = std::move(startClock);
    op.endClock = std::move(endClock);

    myRMAAnnotationHistory[origin].add(op);
}

/**
 * Looks up the recorded history entry of RMA operation @p rmaId issued by the
 * process identified by @p pId.
 *
 * @param pId   parallel id of the issuing process (resolved to its rank).
 * @param rmaId RMA id of the operation.
 * @param op    output: filled by the per-rank history lookup (presumably only when found).
 * @return the result of the history lookup, i.e. whether an entry was found.
 */
bool TargetRaceVerifier::getHistoryData(MustParallelId pId, MustRMAId rmaId, RMAOpHistoryData& op)
{
    const auto originRank = myPIdMod->getInfoForId(pId).rank;
    auto&& rankHistory = myRMAAnnotationHistory[originRank];
    return rankHistory.find(rmaId, op);
}

//=============================
// Destructor.
//=============================
TargetRaceVerifier::~TargetRaceVerifier(void)
{
    // Hand the sub modules bound in the constructor back to the framework and
    // clear each pointer afterwards (also when it was never set).
    if (myPIdMod != nullptr) {
        destroySubModuleInstance(static_cast<I_Module*>(myPIdMod));
    }
    myPIdMod = nullptr;

    if (myDatMod != nullptr) {
        destroySubModuleInstance(static_cast<I_Module*>(myDatMod));
    }
    myDatMod = nullptr;
}
