/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file InfoChecks.cpp
 *       @see MUST::InfoChecks.
 */

#include "GtiMacros.h"
#include "MustEnums.h"
#include "PrefixedOstream.hpp"

#include "InfoChecks.h"

#include <sstream>
#include <unordered_map>

using namespace must;

// GTI module boilerplate for InfoChecks (macros from GtiMacros.h).
// NOTE(review): by their names these presumably generate the instance get/free
// entry points and the PnMPI registration point for this module — confirm in GtiMacros.h.
mGET_INSTANCE_FUNCTION(InfoChecks)
mFREE_INSTANCE_FUNCTION(InfoChecks)
mPNMPI_REGISTRATIONPOINT_FUNCTION(InfoChecks)

//=============================
// Constructor
//=============================
InfoChecks::InfoChecks(const char* instanceName)
    : gti::ModuleBase<InfoChecks, I_InfoChecks>(instanceName)
{
    // Start with every sub-module pointer cleared so that the destructor (which
    // only destroys non-NULL pointers) is safe even if the setup below bails out.
    myPIdMod = NULL;
    myLogger = NULL;
    myArgMod = NULL;
    myInfoMod = NULL;
    myConstMod = NULL;

    // create sub modules
    std::vector<I_Module*> subModInstances = createSubModuleInstances();

    // Sub modules this check uses, in order: parallel id analysis, message
    // logger, argument analysis, info tracking, base constants.
    constexpr std::vector<I_Module*>::size_type numSubModules = 5;

    if (subModInstances.size() < numSubModules) {
        must::cerr << "Module has not enough sub modules, check its analysis specification! ("
                   << __FILE__ << "@" << __LINE__ << ")" << std::endl;
        assert(0);
        // assert() is compiled out under NDEBUG: release what we got and never
        // index past the end of the vector.
        for (I_Module* subMod : subModInstances)
            destroySubModuleInstance(subMod);
        return;
    }

    // Release any surplus sub modules that this module does not use.
    for (auto i = numSubModules; i < subModInstances.size(); ++i)
        destroySubModuleInstance(subModInstances[i]);

    myPIdMod = (I_ParallelIdAnalysis*)subModInstances[0];
    myLogger = (I_CreateMessage*)subModInstances[1];
    myArgMod = (I_ArgumentAnalysis*)subModInstances[2];
    myInfoMod = (I_InfoTrack*)subModInstances[3];
    myConstMod = (I_BaseConstants*)subModInstances[4];

    // Initialize module data
    // Nothing to do
}

//=============================
// Destructor
//=============================
InfoChecks::~InfoChecks()
{
    // Hand every sub module acquired in the constructor back to the module
    // framework (same order as acquired) and clear the pointer afterwards.
    if (myPIdMod != NULL) {
        destroySubModuleInstance((I_Module*)myPIdMod);
        myPIdMod = NULL;
    }

    if (myLogger != NULL) {
        destroySubModuleInstance((I_Module*)myLogger);
        myLogger = NULL;
    }

    if (myArgMod != NULL) {
        destroySubModuleInstance((I_Module*)myArgMod);
        myArgMod = NULL;
    }

    if (myInfoMod != NULL) {
        destroySubModuleInstance((I_Module*)myInfoMod);
        myInfoMod = NULL;
    }

    if (myConstMod != NULL) {
        destroySubModuleInstance((I_Module*)myConstMod);
        myConstMod = NULL;
    }
}

//=============================
// errorIfNotKnown
//=============================
GTI_ANALYSIS_RETURN
InfoChecks::errorIfNotKnown(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    MustInfoType infoHandle)
{
    // An info handle is "known" if the info tracker returns a record for it;
    // anything else is reported as an error on argument aId.
    if (myInfoMod->getInfo(pId, infoHandle) != NULL)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
        << ") is an unknown info handle where a valid info handle was expected.";
    myLogger->createMessage(MUST_ERROR_INFO_UNKNWOWN, pId, lId, MustErrorMessage, msg.str());

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfNull
//=============================
GTI_ANALYSIS_RETURN
InfoChecks::errorIfNull(MustParallelId pId, MustLocationId lId, int aId, MustInfoType infoHandle)
{
    // Only a known handle whose record reports isNull() is flagged; unknown
    // handles (getInfo returning NULL) are left alone here.
    I_Info* const tracked = myInfoMod->getInfo(pId, infoHandle);
    const bool isNullHandle = (tracked != NULL) && tracked->isNull();

    if (!isNullHandle)
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
        << ") is MPI_INFO_NULL where a valid info handle was expected.";
    myLogger->createMessage(MUST_ERROR_INFO_NULL, pId, lId, MustErrorMessage, msg.str());

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfLengthNotWithinRangeZeroAndLessMaxInfoKey
//=============================
GTI_ANALYSIS_RETURN InfoChecks::errorIfLengthNotWithinRangeZeroAndLessMaxInfoKey(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    char* key)
{
    // strnlen scans at most getMaxInfoKey()+1 characters, so the measured length
    // is exact for acceptable keys and only a lower bound for over-long ones
    // (hence "at least" in the message).
    const auto keyLength = strnlen(key, myConstMod->getMaxInfoKey() + 1);

    if (keyLength <= myConstMod->getMaxInfoKey())
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
        << ") is a key, whose length is outside the range of valid values (0-MPI_MAX_INFO_KEY("
        << myConstMod->getMaxInfoKey() << ")), but it is at least " << myArgMod->getArgName(aId)
        << "=" << keyLength << "!";
    myLogger->createMessage(MUST_ERROR_INFO_KEY, pId, lId, MustErrorMessage, msg.str());

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfLengthNotWithinRangeZeroAndLessMaxInfoVal
//=============================
GTI_ANALYSIS_RETURN InfoChecks::errorIfLengthNotWithinRangeZeroAndLessMaxInfoVal(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    char* value)
{
    // strnlen scans at most getMaxInfoVal()+1 characters, so the measured length
    // is exact for acceptable values and only a lower bound for over-long ones
    // (hence "at least" in the message).
    const auto valueLength = strnlen(value, myConstMod->getMaxInfoVal() + 1);

    if (valueLength <= myConstMod->getMaxInfoVal())
        return GTI_ANALYSIS_SUCCESS;

    std::stringstream msg;
    msg << "Argument " << myArgMod->getIndex(aId) << " (" << myArgMod->getArgName(aId)
        << ") is a value, whose length is outside the range of valid values "
           "(0-MPI_MAX_INFO_VAL("
        << myConstMod->getMaxInfoVal() << ")), but it is at least " << myArgMod->getArgName(aId)
        << "=" << valueLength << "!";
    myLogger->createMessage(MUST_ERROR_INFO_VALUE, pId, lId, MustErrorMessage, msg.str());

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfKeyNotDefined
//=============================
GTI_ANALYSIS_RETURN InfoChecks::errorIfKeyNotDefined(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    char* key,
    MustInfoType infoHandle)
{
    // Reports an error if `key` has no entry in the key/value pairs recorded for
    // `infoHandle`; the message includes a dump of the info object.
    I_Info* info = myInfoMod->getInfo(pId, infoHandle);

    // getInfo yields NULL for handles the tracker does not know; that condition
    // is what errorIfNotKnown reports, so just avoid dereferencing it here.
    if (info == NULL)
        return GTI_ANALYSIS_SUCCESS;

    std::unordered_map<std::string, std::string>& keyValues = info->getKeyValPairs();

    if (keyValues.find(key) != keyValues.end()) {
        // key is defined
        return GTI_ANALYSIS_SUCCESS;
    }

    // key is not defined
    std::stringstream stream;
    stream << "Key " << key << " is not defined in argument " << myArgMod->getIndex(aId) << " ("
           << myArgMod->getArgName(aId) << ")!";

    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    stream << "(Information on " << myArgMod->getArgName(aId) << ": ";
    info->printInfo(stream, &refs);
    stream << ")";

    myLogger->createMessage(MUST_ERROR_INFO_NOKEY, pId, lId, MustErrorMessage, stream.str());

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// warningIfKeyNotDefined
//=============================
GTI_ANALYSIS_RETURN InfoChecks::warningIfKeyNotDefined(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    char* key,
    MustInfoType infoHandle)
{
    // Issues a warning if `key` has no entry in the key/value pairs recorded for
    // `infoHandle`; the message includes a dump of the info object.
    I_Info* info = myInfoMod->getInfo(pId, infoHandle);

    // getInfo yields NULL for handles the tracker does not know; that condition
    // is what errorIfNotKnown reports, so just avoid dereferencing it here.
    if (info == NULL)
        return GTI_ANALYSIS_SUCCESS;

    std::unordered_map<std::string, std::string>& keyValues = info->getKeyValPairs();

    if (keyValues.find(key) != keyValues.end()) {
        // key is defined
        return GTI_ANALYSIS_SUCCESS;
    }

    // key is not defined
    std::stringstream stream;
    stream << "Key " << key << " is not defined in argument " << myArgMod->getIndex(aId) << " ("
           << myArgMod->getArgName(aId) << ")!";

    std::list<std::pair<MustParallelId, MustLocationId>> refs;
    stream << "(Information on " << myArgMod->getArgName(aId) << ": ";
    info->printInfo(stream, &refs);
    stream << ")";

    myLogger->createMessage(MUST_WARNING_INFO_NOKEY, pId, lId, MustWarningMessage, stream.str());

    return GTI_ANALYSIS_FAILURE;
}

//=============================
// errorIfNthKeyNotDefined
//=============================
GTI_ANALYSIS_RETURN InfoChecks::errorIfNthKeyNotDefined(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    int n,
    MustInfoType infoHandle)
{
    // Reports an error if the info object has no key at position `n` (in the
    // iteration order of its key/value map); otherwise defers to
    // errorIfKeyNotDefined for that key.
    I_Info* info = myInfoMod->getInfo(pId, infoHandle);

    // getInfo yields NULL for handles the tracker does not know; that condition
    // is what errorIfNotKnown reports, so just avoid dereferencing it here.
    if (info == NULL)
        return GTI_ANALYSIS_SUCCESS;

    std::unordered_map<std::string, std::string>& keyValues = info->getKeyValPairs();
    const int numKeys = static_cast<int>(keyValues.size());

    // A negative index is just as invalid as a too-large one; letting it through
    // would also advance a forward iterator by a negative distance (UB).
    if (n < 0 || n >= numKeys) {
        std::stringstream stream;
        stream << "Key of order " << n << " is not defined in argument " << myArgMod->getIndex(aId)
               << " (" << myArgMod->getArgName(aId) << ")!";

        std::list<std::pair<MustParallelId, MustLocationId>> refs;
        stream << "(Information on " << myArgMod->getArgName(aId) << ": ";
        info->printInfo(stream, &refs);
        stream << ")";

        myLogger->createMessage(MUST_ERROR_INFO_NOKEY, pId, lId, MustErrorMessage, stream.str());

        return GTI_ANALYSIS_FAILURE;
    }

    auto nthElem = keyValues.begin();
    std::advance(nthElem, n);

    // Owned copy of the key (errorIfKeyNotDefined takes a non-const char*);
    // released automatically, unlike a strdup'ed buffer.
    std::string key = nthElem->first;

    return errorIfKeyNotDefined(pId, lId, aId, &key[0], infoHandle);
}

//=============================
// warningIfNthKeyNotDefined
//=============================
GTI_ANALYSIS_RETURN InfoChecks::warningIfNthKeyNotDefined(
    MustParallelId pId,
    MustLocationId lId,
    int aId,
    int n,
    MustInfoType infoHandle)
{
    // Issues a warning if the info object has no key at position `n` (in the
    // iteration order of its key/value map); otherwise defers to
    // warningIfKeyNotDefined for that key.
    I_Info* info = myInfoMod->getInfo(pId, infoHandle);

    // getInfo yields NULL for handles the tracker does not know; that condition
    // is what errorIfNotKnown reports, so just avoid dereferencing it here.
    if (info == NULL)
        return GTI_ANALYSIS_SUCCESS;

    std::unordered_map<std::string, std::string>& keyValues = info->getKeyValPairs();
    const int numKeys = static_cast<int>(keyValues.size());

    // A negative index is just as invalid as a too-large one; letting it through
    // would also advance a forward iterator by a negative distance (UB).
    if (n < 0 || n >= numKeys) {
        std::stringstream stream;
        stream << "Key of order " << n << " is not defined in argument " << myArgMod->getIndex(aId)
               << " (" << myArgMod->getArgName(aId) << ")!";

        std::list<std::pair<MustParallelId, MustLocationId>> refs;
        stream << "(Information on " << myArgMod->getArgName(aId) << ": ";
        info->printInfo(stream, &refs);
        stream << ")";

        myLogger
            ->createMessage(MUST_WARNING_INFO_NOKEY, pId, lId, MustWarningMessage, stream.str());

        return GTI_ANALYSIS_FAILURE;
    }

    auto nthElem = keyValues.begin();
    std::advance(nthElem, n);

    // Owned copy of the key (warningIfKeyNotDefined takes a non-const char*);
    // released automatically, unlike a strdup'ed buffer.
    std::string key = nthElem->first;

    return warningIfKeyNotDefined(pId, lId, aId, &key[0], infoHandle);
}