/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file DGroupMatch.h
 *       @see must::DGroupMatch.
 *
 *  @date 20.01.2012
 *  @author Tobias Hilbrich, Mathias Korepkat, Joachim Protze, Fabian Haensel
 */

#include "GtiEnums.h"
#include "ModuleBase.h"
#include "MustTypes.h"
#include "I_ParallelIdAnalysis.h"
#include "I_LocationAnalysis.h"
#include "I_CreateMessage.h"
#include "CompletionTree.h"
#include "I_WinTrack.h"
#include "I_GroupTrack.h"
#include "I_CommTrack.h"
#include "I_RequestTrack.h"
#include "I_BaseConstants.h"
#include "I_FloodControl.h"
#include "I_Profiler.h"
#include "DistributedDeadlockApi.h"
#include "DGroupOp.h"

#include "I_DGroupMatch.h"

#ifndef DGROUPMATCH_H
#define DGROUPMATCH_H

using namespace gti;

namespace must
{
/**
 * Kinds of handles the group matcher distinguishes
 * (e.g. the handleType of a passed handle, see DGroupMatch::getPassedHandleInfo).
 */
enum MUST_GROUPMATCH_HANDLE {
    MUST_GROUPMATCH_WIN_HANDLE = 0, /**< RMA window handle.*/
    MUST_GROUPMATCH_COMM_HANDLE = 1 /**< Communicator handle.*/
};

/**
 * Identifiers for the kinds of calls handled by the group matcher
 * (cf. DGroupMatch::start, ::post, ::wait and ::complete).
 */
enum MUST_GROUPMATCH_CALLID {
    MUST_GROUPMATCH_START_CALLID = 0,   /**< start call.*/
    MUST_GROUPMATCH_POST_CALLID = 1,    /**< post call.*/
    MUST_GROUPMATCH_WAIT_CALLID = 2,    /**< wait call.*/
    MUST_GROUPMATCH_COMPLETE_CALLID = 3 /**< complete call.*/
};

/**
 * Forward declaration of DGroupOp: the matcher and the op depend on each
 * other cyclically, so DGroupOp may not be declared yet at this point.
 */
class DGroupOp;

/**
 * Plain list of outstanding (not yet matched) operations.
 *
 * Former map-based variant, keyed by dest/source converted to world rank:
 *     typedef std::map<int, std::list<DGroupOp*>> RankOutstandings;
 */
using RankOutstandings = std::list<DGroupOp*>;

/**
 * Send and receive tables of one rank (in comm world) for one handle.
 * Holds every outstanding (unmatched) send-side and receive-side operation.
 */
class ProcessTable
{
  public:
    RankOutstandings sends; /**< Outstanding send-side operations (a plain list).*/
    RankOutstandings recvs; /**< Outstanding receive-side operations (a plain list).*/
};

/**
 * Bundle of data collected for passing an operation onward:
 * ranks and their timestamps, plus optional send counts and datatypes.
 * NOTE(review): presumably the vectors are parallel to `ranks`, and
 * sendcounts/types are only filled by the getPlacesAndNotifyListener
 * overloads that receive counts/types — confirm.
 */
struct PassInfo {
    std::vector<int> ranks;                 /**< Ranks to pass to.*/
    std::vector<MustLTimeStamp> timestamps; /**< Timestamps of the ranks.*/
    std::vector<int> sendcounts;            /**< Send counts.*/
    std::vector<MustDatatypeType> types;    /**< Datatypes.*/
};

/**
 * Info structure that holds everything needed to pass information on a send
 * operation to a sister node. (Currently commented out, as is its only user,
 * DGroupMatch::mySendsToTransfer.)
 */
// class PassSendInfo
// {
//   public:
//     MustParallelId pId;
//     MustLocationId lId;
//     std::set<int> destinations;
//     MustWinType win;
//     int targetPlace;
//     int callId;
//     I_Destructable* hanldeInfo;
// };

/**
 * Implementation of I_DGroupMatch.
 *
 * Declares the handlers for the start/post/complete/wait calls, receivers for
 * operations passed in from other places (counterparts of the wrap-across
 * function pointers below), and the matching queues (myQs) that hold
 * outstanding DGroupOp operations.
 */
class DGroupMatch : public gti::ModuleBase<DGroupMatch, I_DGroupMatch>
{
    // DGroupOp works directly on the matcher's protected state and helpers.
    friend class DGroupOp;

  public:
    /**
     * Creates the module instance.
     * @param instanceName name of this module instance.
     */
    DGroupMatch(const char* instanceName);

    /**
     * Destroys the module instance.
     */
    virtual ~DGroupMatch();

    /** @see I_DGroupMatch::init */
    GTI_ANALYSIS_RETURN init(MustParallelId pId);

    /** @see I_DGroupMatch::post */
    GTI_ANALYSIS_RETURN post(
        MustParallelId pId,
        MustLocationId lId,
        MustGroupType group,
        MustWinType win,
        int callId);

    /** @see I_DGroupMatch::start */
    GTI_ANALYSIS_RETURN start(
        MustParallelId pId,
        MustLocationId lId,
        MustGroupType group,
        MustWinType win,
        int callId);

    /** @see I_DGroupMatch::complete */
    GTI_ANALYSIS_RETURN complete(
        MustParallelId pId,
        MustLocationId lId,
        MustWinType win,
        int callId);

    /** @see I_DGroupMatch::wait */
    GTI_ANALYSIS_RETURN wait(
        MustParallelId pId,
        MustLocationId lId,
        MustWinType win,
        int callId);

    /**
     * Receiver for an operation passed from another place that carries only
     * world ranks and their timestamps.
     * NOTE(review): presumably the counterpart of myPassToGroupNoTransferFunction,
     * with worldRanks/timestamps holding worldRanksCount entries each — confirm.
     */
    GTI_ANALYSIS_RETURN recvPassToGroupNoTransfer(
        MustParallelId pId,
        MustLocationId lId,
        const int* worldRanks,
        const MustLTimeStamp* timestamps,
        int worldRanksCount,
        int handleType,
        MustRemoteIdType remoteHandleId,
        int callId,
        bool isCollective);

    /**
     * Receiver for a passed operation that carries a single datatype and a
     * single send count in addition to the ranks and timestamps.
     * NOTE(review): presumably the counterpart of myPassToGroupNFunction — confirm.
     */
    GTI_ANALYSIS_RETURN recvPassToGroupN(
        MustParallelId pId,
        MustLocationId lId,
        MustDatatypeType type,
        const int sendCount,
        const int* worldRanks,
        const MustLTimeStamp* timestamps,
        int count,
        int handleType,
        MustRemoteIdType remoteHandleId,
        int callId,
        bool isCollective);

    /**
     * Receiver for a passed operation that carries arrays of datatypes and
     * send counts in addition to the ranks and timestamps.
     * NOTE(review): presumably the counterpart of myPassToGroupTypesFunction — confirm.
     */
    GTI_ANALYSIS_RETURN recvPassToGroupTypes(
        MustParallelId pId,
        MustLocationId lId,
        MustDatatypeType* types,
        const int* sendCounts,
        const int* worldRanks,
        const MustLTimeStamp* timestamps,
        int count,
        int handleType,
        MustRemoteIdType remoteHandleId,
        int callId,
        bool isCollective);

    /**
     * Receiver for a passed operation that carries a single datatype and an
     * array of send counts in addition to the ranks and timestamps.
     * NOTE(review): presumably the counterpart of myPassToGroupCountsFunction — confirm.
     */
    GTI_ANALYSIS_RETURN recvPassToGroupCounts(
        MustParallelId pId,
        MustLocationId lId,
        MustDatatypeType type,
        const int* sendCounts,
        const int* worldRanks,
        const MustLTimeStamp* timestamps,
        int count,
        int handleType,
        MustRemoteIdType remoteHandleId,
        int callId,
        bool isCollective);

    /** @see I_DGroupMatch::registerListener */
    void registerListener(I_DGroupListener* listener);

    /** Message creation module. */
    I_CreateMessage* myLogger;

  protected:
    int myPlaceId; /**< Own place ID.*/

    // Child modules.
    I_ParallelIdAnalysis* myPIdMod;
    I_LocationAnalysis* myLIdMod;
    I_BaseConstants* myConsts;
    I_WinTrack* myWTrack;
    I_GroupTrack* myGTrack;
    I_CommTrack* myCTrack;
    I_DatatypeTrack* myDTrack;
    I_FloodControl* myFloodControl;
    I_Profiler* myProfiler;

    // Registered listener (see registerListener).
    I_DGroupListener* myListener;

    // Matching state.
    using ProcessQueues = std::map<I_Destructable*, ProcessTable>; /**< Handle to table.*/
    using QT = std::map<int, ProcessQueues>;                       /**< World rank to queues.*/
    QT myQs;                                                       /**< The matching structure.*/
    uint64_t myQSize;    // NOTE(review): presumably the current number of queued ops — confirm.
    uint64_t myMaxQSize; // NOTE(review): presumably the maximum myQSize reached — confirm.

    // Per window, rank-indexed lists of ranks for post and start calls respectively.
    // NOTE(review): exact key/value meaning not visible here — confirm against the .cpp.
    std::map<MustWinType, std::map<int, std::vector<int>>> myPostRanks;
    std::map<MustWinType, std::map<int, std::vector<int>>> myStartRanks;

    // Disabled: stores send operations that still need to be passed (they may not be passed
    // immediately to sister nodes if DWS reports that they are not active yet).
    // std::map<std::pair<int, MustLTimeStamp>, PassSendInfo> mySendsToTransfer;

    // Function pointers for wrap-across.
    passToGroupNoTransferP myPassToGroupNoTransferFunction;
    passToGroupNP myPassToGroupNFunction;
    passToGroupTypesP myPassToGroupTypesFunction;
    passToGroupCountsP myPassToGroupCountsFunction;

    /**
     * Looks up the persistent information of a window.
     *
     * @param pId context.
     * @param win window handle to look up.
     * @param pOutWin receives the persistent info of win.
     * @return true if successful, false otherwise (persistent handles will already be erased in
     * that case).
     */
    bool getWinInfo(MustParallelId pId, MustWinType win, I_WinPersistent** pOutWin);

    /**
     * Looks up the persistent information of a group.
     *
     * @param pId context.
     * @param group group handle to look up.
     * @param pOutGroup receives the persistent group.
     * @return true if successful, false otherwise (persistent handles will already be erased in
     * that case).
     */
    bool getGroupInfo(MustParallelId pId, MustGroupType group, I_GroupPersistent** pOutGroup);

    /**
     * Looks up the persistent information of a datatype.
     * NOTE(review): presumably follows the same return convention as getWinInfo — confirm.
     */
    bool getTypeInfo(MustParallelId pId, MustDatatypeType type, I_DatatypePersistent** pOutType);

    /**
     * Prints all information in the queues (for debugging).
     */
    void printQs();

    /**
     * Searches the queues for a recv or send that matches the given send or recv,
     * respectively.
     * @param op operation to find a match for.
     * @return true if a match was found.
     */
    bool findMatchingOp(DGroupOp* op);

    /**
     * Adds an unmatched operation to the queues.
     * @param op operation to add.
     */
    void addOutstandingOp(DGroupOp* op);

    /**
     * Looks up the persistent information of a communicator.
     * NOTE(review): presumably follows the same return convention as getWinInfo — confirm.
     */
    bool getCommInfo(MustParallelId pId, MustCommType comm, I_CommPersistent** pOutComm);

    /**
     * Resolves a handle (window or communicator, selected by handleType) that was
     * passed from another place and identified by remoteHandleId.
     * @param pOutInfo receives the resolved handle information.
     * @return presumably true on success — confirm.
     */
    bool getPassedHandleInfo(
        MustParallelId pId,
        MustRemoteIdType remoteHandleId,
        MUST_GROUPMATCH_HANDLE handleType,
        I_Destructable** pOutInfo);

    /**
     * Send-side helpers on communicator comm: a single count and type (N), per-entry
     * counts with a single type (Counts), or per-entry counts and types (Types).
     * NOTE(review): presumably for neighbor-based operations, as the names suggest — confirm.
     */
    GTI_ANALYSIS_RETURN nbrSendN(
        MustParallelId pId,
        MustLocationId lId,
        MustCommType comm,
        int callId,
        int sendcount,
        MustDatatypeType sendtype);

    GTI_ANALYSIS_RETURN nbrSendCounts(
        MustParallelId pId,
        MustLocationId lId,
        MustCommType comm,
        int callId,
        const int* sendcounts,
        MustDatatypeType sendtype);

    GTI_ANALYSIS_RETURN nbrSendTypes(
        MustParallelId pId,
        MustLocationId lId,
        MustCommType comm,
        int callId,
        const int* sendcounts,
        const MustDatatypeType* sendtypes);

    /**
     * Receive-side counterparts of the nbrSend* helpers; additionally take an optional
     * request (hasRequest tells whether request is meaningful).
     * NOTE(review): hasRequest semantics inferred from the parameter names — confirm.
     */
    GTI_ANALYSIS_RETURN nbrRecvN(
        MustParallelId pId,
        MustLocationId lId,
        MustCommType comm,
        int callId,
        int recvcount,
        MustDatatypeType recvtype,
        bool hasRequest,
        MustRequestType request);

    GTI_ANALYSIS_RETURN nbrRecvCounts(
        MustParallelId pId,
        MustLocationId lId,
        MustCommType comm,
        int callId,
        const int* recvcounts,
        MustDatatypeType recvtype,
        bool hasRequest,
        MustRequestType request);

    GTI_ANALYSIS_RETURN nbrRecvTypes(
        MustParallelId pId,
        MustLocationId lId,
        MustCommType comm,
        int callId,
        const int* recvcounts,
        const MustDatatypeType* recvtypes,
        bool hasRequest,
        MustRequestType request);

    /**
     * Notifies the listener about a new operation for the ranks in ranksToNotify.
     * @param pOutTimestamps receives the resulting timestamps.
     * NOTE(review): presumably one timestamp per notified rank — confirm.
     */
    void notifyListenerOfNewOp(
        MustParallelId pId,
        MustLocationId lId,
        I_CommPersistent* cInfo,
        I_WinPersistent* wInfo,
        bool isSend,
        MustSendMode sendMode,
        std::vector<int> ranksToNotify,
        std::vector<MustLTimeStamp>& pOutTimestamps);

    /**
     * Shared implementation of the getPlacesAndNotifyListener overloads; insertExtraFields
     * is a callback that presumably fills the optional PassInfo fields (counts/types).
     */
    template <typename Inserter>
    std::map<int, PassInfo> getPlacesAndNotifyListenerInternal(
        MustParallelId pId,
        MustLocationId lId,
        I_CommPersistent* cInfo,
        I_WinPersistent* wInfo,
        bool isSend,
        MustSendMode sendMode,
        const int* ranksToNotify,
        size_t size,
        Inserter insertExtraFields);

    /**
     * Notifies the listener for the size entries of ranksToNotify and returns PassInfo
     * records keyed by an int.
     * NOTE(review): the key is presumably the place the ranks belong to — confirm.
     * The overloads additionally record per-rank send counts, and counts plus types.
     */
    std::map<int, PassInfo> getPlacesAndNotifyListener(
        MustParallelId pId,
        MustLocationId lId,
        I_CommPersistent* cInfo,
        I_WinPersistent* wInfo,
        bool isSend,
        MustSendMode sendMode,
        const int* ranksToNotify,
        size_t size);

    std::map<int, PassInfo> getPlacesAndNotifyListener(
        MustParallelId pId,
        MustLocationId lId,
        I_CommPersistent* cInfo,
        I_WinPersistent* wInfo,
        bool isSend,
        MustSendMode sendMode,
        const int* ranksToNotify,
        const int* sendCounts,
        size_t size);

    std::map<int, PassInfo> getPlacesAndNotifyListener(
        MustParallelId pId,
        MustLocationId lId,
        I_CommPersistent* cInfo,
        I_WinPersistent* wInfo,
        bool isSend,
        MustSendMode sendMode,
        const int* ranksToNotify,
        const int* sendCounts,
        const MustDatatypeType* types,
        size_t size);
};
} // namespace must

#endif /*DGROUPMATCH_H*/
