/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file TargetRaceVerifier.h
 *       @see must::TargetRaceVerifier.
 *
 *  @date 02.05.2025
 *  @author Simon Schwitanski
 */

#include "ModuleBase.h"
#include "I_ParallelIdAnalysis.h"
#include "I_BaseConstants.h"
#include "I_CreateMessage.h"
#include "I_DatatypeTrack.h"
#include "I_RequestTrack.h"
#include "I_WinTrack.h"
#include "I_TSan.h"
#include "I_TSanSyncClockRecorder.h"
#include "I_LocationAnalysis.h"
#include "I_RMASanitize.h"
#include "I_VectorClock.h"

#include "I_TargetRMAOp.h"
#include "I_TargetRaceVerifier.h"

#include <array>
#include <cstddef>
#include <map>
#include <unordered_map>

#ifndef TargetRaceVerifier_H
#define TargetRaceVerifier_H

// NOTE(review): a using-directive at header scope makes all gti names visible in every
// file that includes this header. Left in place because code outside this header may
// rely on it — confirm before removing.
using namespace gti;

// Number of slots in the backing array of CircularBuffer (used as the std::array extent
// and as the modulus for its index arithmetic).
#define CIRCULAR_BUFFER_SIZE 100

// Circular buffer to store the most recent N = 100 remote RMA accesses from a rank
class CircularBuffer
{
  public:
    CircularBuffer() : data(), start(0), end(0), size(CIRCULAR_BUFFER_SIZE) {}
    void add(RMAOpHistoryData item)
    {
        data[end] = item;
        end = (end + 1) % size;
        if (start == end) { // if buffer is full, also move start
            start = (start + 1) % size;
        }
    };

    bool find(MustRMAId callId, RMAOpHistoryData& result)
    {
        size_t cur = start;
        while (cur != end) {
            if (data[cur].callId == callId) {
                result = data[cur];
                return true;
            }
            cur = (cur + 1) % size;
        }
        return false;
    };

  private:
    std::array<RMAOpHistoryData, CIRCULAR_BUFFER_SIZE> data;
    size_t start;
    size_t end;
    size_t size;
};

namespace must
{
/**
 * Correctness checks for memory operations at the target.
 *
 * Keeps a per-key CircularBuffer of RMA operation history entries
 * (see myRMAAnnotationHistory). The member functions are only declared here;
 * their definitions are not part of this header.
 */
class TargetRaceVerifier : public gti::ModuleBase<TargetRaceVerifier, I_TargetRaceVerifier>
{
  public:
    /**
     * Constructor.
     * @param instanceName name of this module instance.
     */
    TargetRaceVerifier(const char* instanceName);

    /**
     * Destructor.
     */
    virtual ~TargetRaceVerifier(void);

    /**
     * Retrieves the stored history entry of an RMA operation.
     *
     * NOTE(review): presumably looks up the entry in myRMAAnnotationHistory —
     * inferred from the signature, confirm against the definition.
     *
     * @param pId parallel id identifying the process the operation belongs to
     *            (presumably; confirm).
     * @param rmaId id of the RMA operation to look up.
     * @param op output parameter for the history entry.
     * @return presumably true if an entry was found — confirm against the definition.
     */
    bool getHistoryData(MustParallelId pId, MustRMAId rmaId, RMAOpHistoryData& op);

  protected:
    // Handles to other analysis modules (parallel id analysis, datatype tracking).
    // NOTE(review): presumably set up in the constructor — confirm.
    I_ParallelIdAnalysis* myPIdMod;
    I_DatatypeTrack* myDatMod;

  private:
    // store history of annotated RMA calls for race verification
    // NOTE(review): the int key is presumably the origin rank (cf. addHistoryData) — confirm.
    std::unordered_map<int, CircularBuffer> myRMAAnnotationHistory;

    /**
     * Records an RMA operation in the history.
     *
     * NOTE(review): parameter descriptions below are inferred from the parameter
     * names only; confirm against the definition.
     *
     * @param origin origin of the RMA operation (presumably a rank).
     * @param rmaId id of the RMA operation.
     * @param isAtomic whether the operation is atomic.
     * @param epoch epoch the operation belongs to.
     * @param type datatype id of the accessed data.
     * @param typeSize size of the datatype.
     * @param startAddr start address of the access.
     * @param lIdStart location id associated with the start of the operation.
     * @param lIdEnd location id associated with the end of the operation.
     * @param startClock clock value at the start of the operation.
     * @param endClock clock value at the end of the operation.
     */
    void addHistoryData(
        int origin,
        MustRMAId rmaId,
        bool isAtomic,
        int epoch,
        RMADataTypeId type,
        size_t typeSize,
        MustAddressType startAddr,
        MustLocationId lIdStart,
        MustLocationId lIdEnd,
        Clock startClock,
        Clock endClock);

    /**
     * Checks a pair of RMA operations, each identified by a parallel id and an RMA id.
     *
     * NOTE(review): presumably returns true if the two operations race with each
     * other — inferred from the name, confirm against the definition.
     *
     * @param pId1 parallel id of the first operation.
     * @param rmaId1 RMA id of the first operation.
     * @param pId2 parallel id of the second operation.
     * @param rmaId2 RMA id of the second operation.
     */
    bool isRace(MustParallelId pId1, MustRMAId rmaId1, MustParallelId pId2, MustRMAId rmaId2);

    /**
     * Translates a rank in a given communicator to the
     * corresponding rank in MPI_COMM_WORLD.
     *
     * @param comm communicator the rank belongs to
     * @param rank rank to convert
     * @return corresponding rank in MPI_COMM_WORLD
     */
    int translateRank(I_Comm* comm, int rank);

}; /*class TargetRaceVerifier*/
} /*namespace must*/

#endif /*TargetRaceVerifier_H*/
