/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file QOpGroupCompletion.cpp
 *       @see must::QOpGroupCompletion.
 *
 *  @date 14.04.2025
 *  @author Corneius Pätzold
 */

#include "QOpGroupCompletion.h"
#include "DWaitState.h"
#include <string>
#include <fstream>
#include <string.h>

using namespace must;

//=============================
// QOpGroupCompletion -- Constructor
//=============================
/**
 * Constructs a group completion op.
 *
 * The QOp, QOpCompletion and QOpCommunication bases all receive the same
 * wait state, parallel id, location id and timestamp. QOpCompletion is also
 * given the count and timestamps arguments plus a literal `true` flag
 * (presumably "waits for all" — TODO confirm against QOpCompletion.h).
 * QOpCommunication is given the communicator.
 */
QOpGroupCompletion::QOpGroupCompletion(
    DWaitState* dws,
    MustParallelId pId,
    MustLocationId lId,
    MustLTimeStamp ts,
    int count,
    MustLTimeStamp* timestamps,
    I_CommPersistent* comm)
    : QOp(dws, pId, lId, ts),
      QOpCompletion(dws, pId, lId, ts, count, timestamps, true),
      QOpCommunication(dws, pId, lId, ts, comm),
      myRequest(),
      myRequests(),
      myWaitsForAll(true), // true -> ARC_AND arc type when forwarding wait-for info
      myNumCompleted(0),
      myMatchIndex(-1)
{
    // Everything is set up by the initializer list above.
}

//=============================
// QOpGroupCompletion -- Destructor
//=============================
/**
 * Destructor.
 *
 * Propagates the erase to the non-blocking operation(s) this completion
 * refers to:
 * - the single request (QOpCompletion::myRequest) when the request list is
 *   empty, or
 * - every entry of QOpCompletion::myRequests, after which the list is cleared.
 *
 * A request may have no non-blocking op attached (forwardThisOpsWaitForInformation
 * explicitly checks for this). Such entries are skipped here instead of being
 * dereferenced.
 */
QOpGroupCompletion::~QOpGroupCompletion(void)
{
    if (QOpCompletion::myRequests.empty()) {
        if (QOpCompletion::myRequest.nonBlockingOp)
            QOpCompletion::myRequest.nonBlockingOp->erase();
    } else {
        for (auto& req : QOpCompletion::myRequests) {
            if (req.nonBlockingOp)
                req.nonBlockingOp->erase();
        }
        QOpCompletion::myRequests.clear();
    }
}

//=============================
// printAsDot
//=============================
/**
 * DOT output for a group completion is exactly that of its QOpCompletion
 * part. The arguments are handed through unchanged and the result is returned.
 */
std::string QOpGroupCompletion::printAsDot(
    std::ofstream& out,
    std::string nodePrefix,
    std::string color)
{
    std::string dotResult = QOpCompletion::printAsDot(out, nodePrefix, color);
    return dotResult;
}

//=============================
// printVariablesAsLabelString
//=============================
/**
 * Returns the label string produced by the QOpCompletion part of this op.
 */
std::string QOpGroupCompletion::printVariablesAsLabelString(void)
{
    std::string labelText = QOpCompletion::printVariablesAsLabelString();
    return labelText;
}

//=============================
// notifyActive
//=============================
/**
 * Forwards the activation notification to the QOpCompletion part of this op.
 */
void QOpGroupCompletion::notifyActive(void)
{
    QOpCompletion::notifyActive();
}

//=============================
// blocks
//=============================
/**
 * Whether this op blocks.
 *
 * A group completion that carries its own request behaves like a
 * "non-blocking waitall". The actual blocking happens later in the MPI_Wait*
 * call that completes that request, so in that case this op never blocks.
 * Otherwise the QOpCompletion decision applies. The short-circuit ensures
 * QOpCompletion::blocks() is only queried when there is no request.
 */
bool QOpGroupCompletion::blocks(void)
{
    return !hasRequest() && QOpCompletion::blocks();
}

//=============================
// isMatchedWithActiveOps
//=============================
/**
 * Returns true when, after refreshing the completion state, the
 * QOpCompletion part no longer blocks.
 *
 * Note: unlike blocks(), this intentionally consults QOpCompletion::blocks()
 * directly and does not short-circuit on hasRequest().
 */
bool QOpGroupCompletion::isMatchedWithActiveOps(void)
{
    checkAlive();

    // Pull in any pending completion updates before asking whether we still block.
    notifyActive();

    const bool stillBlocked = QOpCompletion::blocks();
    return !stillBlocked;
}

//=============================
// needsToBeInTrace
//=============================
/**
 * Trace membership is decided by the QOpCompletion part of this op.
 */
bool QOpGroupCompletion::needsToBeInTrace(void)
{
    const bool inTrace = QOpCompletion::needsToBeInTrace();
    return inTrace;
}

//=============================
// forwardWaitForInformation
//=============================
/**
 * Forwards this op's wait-for information, but only while it blocks.
 *
 * @param commLabels communicator label map handed on to the label builder.
 * @param winLabels  window label map handed on to the label builder.
 */
void QOpGroupCompletion::forwardWaitForInformation(
    std::map<I_Comm*, std::string>& commLabels,
    std::map<I_Win*, std::string>& winLabels)
{
    // A non-blocking op contributes nothing; a blocking one reports itself with
    // sub id -1 (presumably "no enclosing sub-node" — TODO confirm).
    if (blocks()) {
        forwardThisOpsWaitForInformation(-1, commLabels, winLabels);
    }
}

//=============================
// forwardThisOpsWaitForInformation
//=============================
/**
 * Provides this op's wait-for information to the single-op callback obtained
 * from the wait state (myState->getProvideWaitSingleFunction()).
 *
 * Which requests count:
 * - A request contributes a sub-node only if it is not completed, has a
 *   non-blocking op attached, and that op is not yet matched with active ops.
 * - Only point-to-point ops are considered.
 * - In the single-request case a wildcard op aborts without calling the
 *   callback. In the multi-request case wildcard requests are skipped.
 *
 * The arc type is ARC_AND when myWaitsForAll is set, otherwise ARC_OR.
 *
 * @param subIdToUse sub id passed unchanged to the callback.
 * @param commLabels communicator label map, passed to getLabels().
 * @param winLabels  window label map, passed to getLabels().
 */
void QOpGroupCompletion::forwardThisOpsWaitForInformation(
    int subIdToUse,
    std::map<I_Comm*, std::string>& commLabels,
    std::map<I_Win*, std::string>& winLabels)
{
    //==Get the function
    provideWaitForInfosSingleP fSingle = myState->getProvideWaitSingleFunction();

    //==How many sub nodes?
    int numSubs = 0;
    std::stringstream labelstream;
    std::vector<int> toRanks;
    if (QOpCompletion::myRequests.size() == 0) {
        if (QOpCompletion::myRequest.completed || !QOpCompletion::myRequest.nonBlockingOp ||
            QOpCompletion::myRequest.nonBlockingOp->isMatchedWithActiveOps())
            return;

        QOpCommunicationP2P* p2pOp =
            QOpCompletion::myRequest.nonBlockingOp->asOpCommunicationP2P();
        if (p2pOp) {
            bool isWc = false;
            int sourceTarget = p2pOp->getSourceTarget(&isWc);

            if (isWc) {
                // This should never be a WC
                return;
            }

            numSubs = 1;
            labelstream << p2pOp->getCompletionEdgeLabel()
                        << p2pOp->getLabels(1, commLabels, winLabels).str() << '\n';

            toRanks.push_back(sourceTarget);
        } else {
            // GroupCompletion currently only uses P2P communications.
        }
    } else {
        /**
         * @label  [REDUCTION]
         * We apply some WFG reductions here to avoid the use of too many irrelevant sub-nodes:
         * - For P2POps that don't use a wildcard we can skip a request if we already had a similar
         *    sub-node (i.e. one with the same target rank), irrespective of it being an all/any
         *    wait.
         * - Wildcard P2POps are currently not part of group completions.
         */
        std::set<int> usedTargetRanks;

        for (std::vector<RequestInfo>::size_type i = 0; i < QOpCompletion::myRequests.size(); i++) {
            // Is this a relevant request (unmatched, valid, ...)
            if (QOpCompletion::myRequests[i].completed ||
                !QOpCompletion::myRequests[i].nonBlockingOp ||
                QOpCompletion::myRequests[i].nonBlockingOp->isMatchedWithActiveOps())
                continue;

            QOpCommunicationP2P* p2pOp =
                QOpCompletion::myRequests[i].nonBlockingOp->asOpCommunicationP2P();
            if (!p2pOp) {
                // GroupCompletion currently only uses P2P communications.
                continue;
            }

            ////HANDLING FOR P2P Ops with reduction mentioned above
            // Can we apply a WFG reduction for this request (does it add relevant or redundant
            // WFG information)?
            bool isWc = false;
            int sourceTarget = p2pOp->getSourceTarget(&isWc);

            if (isWc) {
                // This should never be a WC
                continue;
            }

            // insert().second is false when this target rank already has a sub-node.
            if (!usedTargetRanks.insert(sourceTarget).second)
                continue;

            // It's valid and important, so it counts!
            numSubs++;
            labelstream << p2pOp->getCompletionEdgeLabel() << "[" << i << "], "
                        << p2pOp->getLabels(1, commLabels, winLabels).str() << '\n';

            toRanks.push_back(sourceTarget);
        }
    }

    //==Prepare
    std::vector<MustParallelId> labelPIds(numSubs, 0);
    std::vector<MustLocationId> labelLIds(numSubs, 0);

    // The callback receives the labels as a char* buffer plus its length,
    // which includes the terminating NUL. Build the string once and keep the
    // buffer in a vector so it is released on every exit path (no manual
    // new/delete). The length keeps the original strlen()+1 semantics.
    const std::string labels = labelstream.str();
    const size_t length = strlen(labels.c_str()) + 1;
    std::vector<char> labelsConcat(labels.c_str(), labels.c_str() + length);

    const ArcType arcT = myWaitsForAll ? ARC_AND : ARC_OR;

    (*fSingle)(
        myRank,
        myPId,
        myLId,
        subIdToUse,
        numSubs,
        static_cast<int>(arcT),
        toRanks.data(),
        labelPIds.data(),
        labelLIds.data(),
        length,
        labelsConcat.data());
}

//=============================
// getUsedComms
//=============================
/**
 * Communicators used by this op, as reported by its QOpCompletion part.
 */
std::list<I_Comm*> QOpGroupCompletion::getUsedComms(void)
{
    std::list<I_Comm*> usedComms = QOpCompletion::getUsedComms();
    return usedComms;
}

//=============================
// getUsedWins
//=============================
/**
 * Windows used by this op, as reported by its QOpCompletion part.
 */
std::list<I_Win*> QOpGroupCompletion::getUsedWins(void)
{
    std::list<I_Win*> usedWins = QOpCompletion::getUsedWins();
    return usedWins;
}

//=============================
// asOpGroupCompletion
//=============================
/**
 * Down-cast hook: after verifying the op is still alive, returns this object
 * as a QOpGroupCompletion.
 */
QOpGroupCompletion* QOpGroupCompletion::asOpGroupCompletion(void)
{
    checkAlive();
    QOpGroupCompletion* const self = this;
    return self;
}

//=============================
// getPingPongNodes
//=============================
/**
 * Ping-pong nodes of this op, as reported by its QOpCompletion part.
 */
std::set<int> QOpGroupCompletion::getPingPongNodes(void)
{
    std::set<int> pingPongNodes = QOpCompletion::getPingPongNodes();
    return pingPongNodes;
}

/*EOF*/
