/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file DGroupOp.cpp
 *       @see must::DGroupOp.
 *
 *  @date 20.01.2012
 *  @author Tobias Hilbrich, Mathias Korepkat, Joachim Protze, Fabian Haensel
 */

#include "DGroupOp.h"
#include <cassert>
#include <cstddef>
#include <fstream>
#include <sstream>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include "MustDefines.h"
#include "MustOutputdir.h"
#include "Command.hpp"

#define MUST_DEBUG

using namespace must;

DGroupOpPartnerInfo::DGroupOpPartnerInfo(
    DGroupOp* op,
    int targetRank,
    MustLTimeStamp timestamp,
    bool isSend)
    : myTargetRank(targetRank), myTypeHandle(0), myType(NULL), myCount(0), myTimestamp(timestamp),
      myIsSend(isSend), myOp(op)
{
}

/**
 * Builds a partner descriptor that carries a (datatype, count) pair used for
 * type matching against the partner's descriptor.
 *
 * @param op owning group operation (stored as a back-pointer).
 * @param targetRank rank of the partner process.
 * @param typeHandle datatype handle as passed by the application.
 * @param dInfo persistent datatype information; a reference is taken via copy().
 * @param count number of elements of dInfo.
 * @param timestamp logical timestamp associated with this partner.
 * @param isSend whether this side is the sending one.
 */
DGroupOpPartnerInfo::DGroupOpPartnerInfo(
    DGroupOp* op,
    int targetRank,
    MustDatatypeType typeHandle,
    I_DatatypePersistent* dInfo,
    int count,
    MustLTimeStamp timestamp,
    bool isSend)
    : myTargetRank(targetRank), myTypeHandle(typeHandle), myType(dInfo), myCount(count),
      myTimestamp(timestamp), myIsSend(isSend), myOp(op)
{
    // Take a reference on the datatype; the destructor releases it with erase().
    // The destructor and matchTypes() both tolerate a NULL type, so guard here as well
    // instead of dereferencing unconditionally.
    if (myType)
        myType->copy();
}

/**
 * Releases the datatype reference taken by the typed constructor (if any).
 */
DGroupOpPartnerInfo::~DGroupOpPartnerInfo()
{
    if (myType != nullptr)
        myType->erase();
    myType = nullptr;
}

/**
 * Compares this partner's (datatype, count) pair with the one of @p other and
 * reports a MUST error message if the type signatures do not match.
 *
 * @param other the partner descriptor to compare against.
 * @return false if either side carries no datatype information (no check was
 *         possible); true otherwise, regardless of whether a mismatch was reported.
 */
bool DGroupOpPartnerInfo::matchTypes(DGroupOpPartnerInfo* other)
{
    if (!other || !myType || !other->myType)
        return false;

    MustAddressType pos = 0;
    const MustMessageIdNames ret =
        myType->isEqualB(myCount, other->myType, other->myCount, &pos);

    std::stringstream stream;
    std::list<std::pair<MustParallelId, MustLocationId>> references;

    std::string sendOpName = "this operations";
    std::string receiveOpName = "the other operations";

    switch (ret) {
    case MUST_ERROR_TYPEMATCH_MISMATCH:
        stream << "Two neighborhood collective calls cause a type mismatch!";
        break;
    case MUST_ERROR_TYPEMATCH_LENGTH:
        stream << "Two neighborhood collective operations use (datatype,count) pairs that span "
                  "type signatures of different length!"
               << " Each send and receive transfer of a collective call must use equal type "
                  "signatures (I.e. same types with potentially different displacements).";

        break;
    default:
        // Types match (or the result is not a mismatch we report): nothing to do.
        return true;
    }

    myOp->printHandleInfo(stream, &references);
    stream << " (Information on " << sendOpName << " transfer of count " << myCount
           << " with type:";

    myType->printInfo(stream, &references);
    stream << ")";

    stream << " (Information on " << receiveOpName << " transfer of count " << other->myCount
           << " with type:";

    other->myType->printInfo(stream, &references);
    // Close the second information block; previously the message ended unbalanced.
    stream << ")";

    myOp->myMatcher->myLogger->createMessage(
        ret,
        myOp->getPId(),
        myOp->getLId(),
        MustErrorMessage,
        stream.str(),
        references);

    return true;
}

//=============================
// Constructor
//=============================
/**
 * Creates a group operation whose partners carry no datatype information.
 *
 * @param ranks partner ranks; must have the same length as @p timestamps.
 * @param timestamps logical timestamps, one per partner rank.
 */
DGroupOp::DGroupOp(
    MustParallelId pId,
    MustLocationId lId,
    I_Destructable* persistentHandle,
    DGroupMatch* matcher,
    bool isSend,
    std::vector<int> ranks,
    std::vector<MustLTimeStamp> timestamps,
    int callId,
    bool isCollective)
    : myMatcher(matcher), myIsSend(isSend), myPersistentHandle(persistentHandle), myCallId(callId),
      myPId(pId), myLId(lId), myIsCollective(isCollective)
{
    // Take a reference on the handle (released in ~DGroupOp, which also tolerates NULL).
    if (myPersistentHandle)
        myPersistentHandle->copy();
    myRank = myMatcher->myPIdMod->getInfoForId(myPId).rank;
    assert(ranks.size() == timestamps.size());
    myPartnerInfos.reserve(ranks.size());
    for (std::size_t i = 0; i < ranks.size(); ++i)
        myPartnerInfos.push_back(new DGroupOpPartnerInfo(this, ranks[i], timestamps[i], isSend));
}

//=============================
// Constructor
//=============================
/**
 * Creates a group operation in which every partner uses the same
 * (datatype, count) pair.
 *
 * @param ranks partner ranks; must have the same length as @p timestamps.
 * @param timestamps logical timestamps, one per partner rank.
 * @param typeHandle datatype handle shared by all partners.
 * @param dInfo persistent datatype information shared by all partners.
 * @param sendCount element count shared by all partners.
 */
DGroupOp::DGroupOp(
    MustParallelId pId,
    MustLocationId lId,
    I_Destructable* persistentHandle,
    DGroupMatch* matcher,
    bool isSend,
    std::vector<int> ranks,
    std::vector<MustLTimeStamp> timestamps,
    MustDatatypeType typeHandle,
    I_DatatypePersistent* dInfo,
    int sendCount,
    int callId,
    bool isCollective)
    : myMatcher(matcher), myIsSend(isSend), myPersistentHandle(persistentHandle), myCallId(callId),
      myPId(pId), myLId(lId), myIsCollective(isCollective)
{
    // Take a reference on the handle (released in ~DGroupOp, which also tolerates NULL).
    if (myPersistentHandle)
        myPersistentHandle->copy();
    myRank = myMatcher->myPIdMod->getInfoForId(myPId).rank;
    assert(ranks.size() == timestamps.size());
    myPartnerInfos.reserve(ranks.size());
    for (std::size_t i = 0; i < ranks.size(); ++i) {
        myPartnerInfos.push_back(new DGroupOpPartnerInfo(
            this,
            ranks[i],
            typeHandle,
            dInfo,
            sendCount,
            timestamps[i],
            isSend));
    }
}

//=============================
// Constructor
//=============================
/**
 * Creates a group operation in which every partner has its own
 * (datatype, count) pair.
 *
 * @param ranks partner ranks; must have the same length as @p timestamps.
 * @param timestamps logical timestamps, one per partner rank.
 * @param typeHandles datatype handle per partner (indexed like @p ranks).
 * @param dInfos persistent datatype information per partner (indexed like @p ranks).
 * @param sendCounts element count per partner (indexed like @p ranks).
 */
DGroupOp::DGroupOp(
    MustParallelId pId,
    MustLocationId lId,
    I_Destructable* persistentHandle,
    DGroupMatch* matcher,
    bool isSend,
    std::vector<int> ranks,
    std::vector<MustLTimeStamp> timestamps,
    std::vector<MustDatatypeType> typeHandles,
    std::vector<I_DatatypePersistent*> dInfos,
    std::vector<int> sendCounts,
    int callId,
    bool isCollective)
    : myMatcher(matcher), myIsSend(isSend), myPersistentHandle(persistentHandle), myCallId(callId),
      myPId(pId), myLId(lId), myIsCollective(isCollective)
{
    // Take a reference on the handle (released in ~DGroupOp, which also tolerates NULL).
    if (myPersistentHandle)
        myPersistentHandle->copy();
    myRank = myMatcher->myPIdMod->getInfoForId(myPId).rank;
    assert(ranks.size() == timestamps.size());
    // The per-partner vectors are indexed with the indices of ranks; catch short
    // inputs in debug builds instead of reading past their end.
    assert(typeHandles.size() >= ranks.size());
    assert(dInfos.size() >= ranks.size());
    assert(sendCounts.size() >= ranks.size());
    myPartnerInfos.reserve(ranks.size());
    for (std::size_t i = 0; i < ranks.size(); ++i) {
        myPartnerInfos.push_back(new DGroupOpPartnerInfo(
            this,
            ranks[i],
            typeHandles[i],
            dInfos[i],
            sendCounts[i],
            timestamps[i],
            isSend));
    }
}

//=============================
// Constructor
//=============================
/**
 * Creates a group operation in which all partners share one datatype but each
 * partner has its own element count.
 *
 * @param ranks partner ranks; must have the same length as @p timestamps.
 * @param timestamps logical timestamps, one per partner rank.
 * @param typeHandle datatype handle shared by all partners.
 * @param dInfo persistent datatype information shared by all partners.
 * @param sendCounts element count per partner (indexed like @p ranks).
 */
DGroupOp::DGroupOp(
    MustParallelId pId,
    MustLocationId lId,
    I_Destructable* persistentHandle,
    DGroupMatch* matcher,
    bool isSend,
    std::vector<int> ranks,
    std::vector<MustLTimeStamp> timestamps,
    MustDatatypeType typeHandle,
    I_DatatypePersistent* dInfo,
    std::vector<int> sendCounts,
    int callId,
    bool isCollective)
    : myMatcher(matcher), myIsSend(isSend), myPersistentHandle(persistentHandle), myCallId(callId),
      myPId(pId), myLId(lId), myIsCollective(isCollective)
{
    // Take a reference on the handle (released in ~DGroupOp, which also tolerates NULL).
    if (myPersistentHandle)
        myPersistentHandle->copy();
    myRank = myMatcher->myPIdMod->getInfoForId(myPId).rank;
    assert(ranks.size() == timestamps.size());
    // sendCounts is indexed with the indices of ranks.
    assert(sendCounts.size() >= ranks.size());
    myPartnerInfos.reserve(ranks.size());
    for (std::size_t i = 0; i < ranks.size(); ++i) {
        myPartnerInfos.push_back(new DGroupOpPartnerInfo(
            this,
            ranks[i],
            typeHandle,
            dInfo,
            sendCounts[i],
            timestamps[i],
            isSend));
    }
}

//=============================
// ~DGroupOp
//=============================
/**
 * Deletes the partner descriptors held in myPartnerInfos and releases the
 * reference on the persistent handle.
 */
DGroupOp::~DGroupOp()
{
    for (auto it = myPartnerInfos.begin(); it != myPartnerInfos.end(); ++it)
        delete *it;
    myPartnerInfos.clear();

    if (myPersistentHandle != nullptr)
        myPersistentHandle->erase();
    myPersistentHandle = nullptr;

    myMatcher = nullptr;
}

/**
 * Hands this operation to the matcher. If a partner operation completes it,
 * the op deletes itself; otherwise it is queued as outstanding.
 * The rank argument is not used.
 */
PROCESSING_RETURN DGroupOp::process(int rank)
{
    const bool matched = myMatcher->findMatchingOp(this);
    if (!matched) {
        // No partner yet: keep this op in the matcher's queues for later matching.
        myMatcher->addOutstandingOp(this);
        return PROCESSING_SUCCESS;
    }

    // Matched and completed: no member may be touched after this point.
    delete this;
    return PROCESSING_SUCCESS;
}

//=============================
// print
//=============================
/**
 * Writes a short description of this operation (currently only its direction,
 * "Send" or "Recv") to @p out.
 */
GTI_RETURN DGroupOp::print(std::ostream& out)
{
    out << (myIsSend ? "Send" : "Recv");
    return GTI_SUCCESS;
}

//=============================
// getPersistentHandleCopy
//=============================
/**
 * Returns the persistent handle after calling copy() on it, so the caller
 * presumably holds its own reference that it must later erase() — confirm
 * against the I_Destructable contract.
 */
I_Destructable* DGroupOp::getPersistentHandleCopy()
{
    I_Destructable* const handle = myPersistentHandle;
    handle->copy();
    return handle;
}

//=============================
// getPersistentHandle
//=============================
/** Returns the persistent handle without taking an additional reference. */
I_Destructable* DGroupOp::getPersistentHandle()
{
    return myPersistentHandle;
}

//=============================
// compareHandles
//=============================
/**
 * Checks whether @p other refers to the same window/communicator as this op.
 *
 * @param other operation to compare with; may be NULL.
 * @return false for a NULL operation (mirroring the I_Destructable overload,
 *         whose dynamic_casts yield false for NULL), otherwise the handle comparison.
 */
bool DGroupOp::compareHandles(DGroupOp* other)
{
    if (!other)
        return false;
    return compareHandles(other->getPersistentHandle());
}

//=============================
// compareHandles
//=============================
/**
 * Compares the persistent handle of this op with @p other.
 * Windows are compared only with windows and communicators only with
 * communicators; any other combination yields false.
 */
bool DGroupOp::compareHandles(I_Destructable* other)
{
    auto* const otherWin = dynamic_cast<I_WinPersistent*>(other);
    if (otherWin) {
        auto* const ownWin = dynamic_cast<I_WinPersistent*>(myPersistentHandle);
        return ownWin && ownWin->compareWins(otherWin);
    }

    auto* const otherComm = dynamic_cast<I_CommPersistent*>(other);
    if (otherComm) {
        auto* const ownComm = dynamic_cast<I_CommPersistent*>(myPersistentHandle);
        return ownComm && ownComm->compareComms(otherComm);
    }

    return false;
}

//=============================
// printHandleInfo
//=============================
void DGroupOp::printHandleInfo(
    std::stringstream& stream,
    std::list<std::pair<MustParallelId, MustLocationId>>* references)
{
    if (I_WinPersistent* myWin = dynamic_cast<I_WinPersistent*>(myPersistentHandle)) {
        stream << " (Information on window: ";
        myWin->printInfo(stream, references);
        stream << ")";
    } else if (I_CommPersistent* myComm = dynamic_cast<I_CommPersistent*>(myPersistentHandle)) {
        stream << " (Information on communicator: ";
        myComm->printInfo(stream, references);
        stream << ")";
    }
}

//=============================
// getPartnerInfos
//=============================
/** Returns a mutable reference to the partner descriptors of this op. */
std::vector<DGroupOpPartnerInfo*>& DGroupOp::getPartnerInfos()
{
    return myPartnerInfos;
}

//=============================
// isSend
//=============================
/** True if this op is the sending side. */
bool DGroupOp::isSend()
{
    return myIsSend;
}

//=============================
// isCollective
//=============================
/** True if this op stems from a collective call. */
bool DGroupOp::isCollective()
{
    return myIsCollective;
}

//=============================
// getCallId
//=============================
/** Returns the call id this op was created with. */
int DGroupOp::getCallId()
{
    return myCallId;
}

//=============================
// getIssuerRank
//=============================
/** Returns the rank (derived from the parallel id) that issued this op. */
int DGroupOp::getIssuerRank()
{
    return myRank;
}

//=============================
// getPId
//=============================
/** Returns the parallel id of the call that created this op. */
MustParallelId DGroupOp::getPId()
{
    return myPId;
}

//=============================
// getLId
//=============================
/** Returns the location id of the call that created this op. */
MustLocationId DGroupOp::getLId()
{
    return myLId;
}

//=============================
// copy
//=============================
/** Allocates a duplicate of this op via the copy-from constructor. */
DGroupOp* DGroupOp::copy()
{
    DGroupOp* const duplicate = new DGroupOp(this);
    return duplicate;
}

//=============================
// Constructor (copy from)
//=============================
/**
 * Creates an independent duplicate of @p from.
 *
 * The partner descriptors are deep-copied. ~DGroupOp deletes every entry of
 * myPartnerInfos, so sharing the raw pointers of @p from would delete them twice.
 * Each descriptor's back-pointer must also refer to this op rather than to
 * @p from, which may be deleted first.
 *
 * @param from operation to duplicate; its persistent handle gets an extra reference.
 */
DGroupOp::DGroupOp(DGroupOp* from)
{
    myMatcher = from->myMatcher;
    myIsSend = from->myIsSend;
    myRank = from->myRank;
    myIsCollective = from->myIsCollective;
    myCallId = from->myCallId;

    // The persistent handle of the send/recv; take our own reference on it.
    myPersistentHandle = from->myPersistentHandle;
    if (myPersistentHandle)
        myPersistentHandle->copy();

    myPId = from->myPId;
    myLId = from->myLId;

    myPartnerInfos.reserve(from->myPartnerInfos.size());
    for (const auto* info : from->myPartnerInfos) {
        if (info->myType) {
            // Typed constructor takes its own reference on the datatype.
            myPartnerInfos.push_back(new DGroupOpPartnerInfo(
                this,
                info->myTargetRank,
                info->myTypeHandle,
                info->myType,
                info->myCount,
                info->myTimestamp,
                info->myIsSend));
        } else {
            myPartnerInfos.push_back(new DGroupOpPartnerInfo(
                this,
                info->myTargetRank,
                info->myTimestamp,
                info->myIsSend));
        }
    }
}

//=============================
// printCollectiveMismatch
//=============================
/**
 * Reports a collective call mismatch between this op and @p other.
 * A mismatch exists if both ops are collective but carry different call ids.
 *
 * @param other the op that was matched against this one.
 * @return always true.
 */
bool DGroupOp::printCollectiveMismatch(DGroupOp* other)
{
    // Sanity check: only collective ops with differing call ids constitute a mismatch.
    if (!myIsCollective || !other->myIsCollective)
        return true;
    if (myCallId == other->myCallId)
        return true;

    std::list<std::pair<MustParallelId, MustLocationId>> references;
    std::stringstream stream;

    stream << "A collective mismatch occured (The application executes two different neighborhood "
              "collective "
              "calls on the same communicator)! "
           << "The collective operation that does not matches this operation was executed at "
              "reference 1.";
    // Reference 1 must be the other op, so push it before printHandleInfo adds more.
    references.push_back(std::make_pair(other->myPId, other->myLId));

    // printHandleInfo() already emits the " (Information on ...: " prefix and the closing
    // paren; wrapping it again produced a duplicated, unbalanced message.
    printHandleInfo(stream, &references);
    stream << '\n'
           << "Note that collective matching was disabled as a result, collectives won't be "
              "analysed for their correctness or blocking state anymore. You should solve this "
              "issue and rerun your application with MUST.";

    myMatcher->myLogger->createMessage(
        MUST_ERROR_COLLECTIVE_CALL_MISMATCH,
        myPId,
        myLId,
        MustErrorMessage,
        stream.str(),
        references);

    return true;
}

/*EOF*/
