/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file InitLocationId.cpp
 *       @see must::InitLocationId.
 *
 *  @date 24.04.2014
 *  @author Tobias Hilbrich
 */

#include <assert.h>
#include <dlfcn.h>
#include <strings.h>

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <map>
#include <string>
#include <vector>

#include "GtiMacros.h"
#include "MustDefines.h"
#include "PrefixedOstream.hpp"
#include <assert.h>
#include <atomic>
#include <pnmpi.h>

#include "InitLocationId.h"
#include "pnmpi/service.h"

#ifdef BUILD_BACKWARD
#include "backward.hpp"
// Per-thread backward-cpp resolver; init() uses it to turn captured stack traces into
// function / file / line information for newly seen locations.
thread_local backward::TraceResolver must::InitLocationId::tr{};
#endif

using namespace must;

// GTI module boilerplate. Presumably these generate the instance getter/free functions and
// the PnMPI registration point for this module (TODO confirm against GtiMacros.h).
mGET_INSTANCE_FUNCTION(InitLocationId)
mFREE_INSTANCE_FUNCTION(InitLocationId)
mPNMPI_REGISTRATIONPOINT_FUNCTION(InitLocationId)

//=============================
// Constructor
//=============================
/**
 * Creates the module, acquires its sub modules and reads its environment configuration.
 *
 * Sub modules (in specification order): [0] I_InitParallelId, [1] I_GenerateLocationId.
 * Any additional sub modules are destroyed immediately.
 *
 * Environment variables:
 *  - MUST_STRIP_PATH_PREFIX: prefix removed from source file names of resolved stack frames.
 *  - MUST_STACKTRACE: "addr2line" or "backward" (case-insensitive) select the stack trace
 *    mode; any other value (or none) keeps MUST_STACKTRACE_NONE. "backward" is only honored
 *    when MUST was built with BUILD_BACKWARD.
 *
 * @param instanceName name of this module instance, forwarded to gti::ModuleBase.
 */
InitLocationId::InitLocationId(const char* instanceName)
    : gti::ModuleBase<InitLocationId, I_InitLocationId>(instanceName), myKnownLocations(),
      myStackTraceMode(MUST_STACKTRACE_NONE)
{
    // Explicitly start from a null state: the destructor only releases non-null sub modules,
    // and the early-return path below may leave these unassigned otherwise.
    myPIdInit = nullptr;
    myGenLId = nullptr;
    myNewLocFct = nullptr;

    // create sub modules
    std::vector<I_Module*> subModInstances = createSubModuleInstances();

    // handle sub modules
    constexpr std::vector<I_Module*>::size_type numModsRequired = 2;
    if (subModInstances.size() < numModsRequired) {
        must::cerr << "Module has not enough sub modules, check its analysis specification! ("
                   << __FILE__ << "@" << __LINE__ << ")" << std::endl;
        assert(0);
        // With assertions disabled, never index past the end of subModInstances below;
        // release what was created and leave the module in its (unusable) null state.
        for (I_Module* subMod : subModInstances)
            destroySubModuleInstance(subMod);
        return;
    }
    // Only the first numModsRequired sub modules are used; release the rest.
    for (auto i = numModsRequired; i < subModInstances.size(); ++i)
        destroySubModuleInstance(subModInstances[i]);

    myPIdInit = (I_InitParallelId*)subModInstances[0];
    myGenLId = (I_GenerateLocationId*)subModInstances[1];

    // Module data
    getWrapperFunction("handleNewLocation", (GTI_Fct_t*)&myNewLocFct);

    if (!myNewLocFct) {
        must::cerr << "InitLocationId module could not find the \"handleNewLocation\" function and "
                      "will not operate correctly as a result. Check the module mappings and "
                      "specifications for this module and the function. Aborting."
                   << std::endl;
        assert(0);
    }

    const char* stripPrefix = std::getenv("MUST_STRIP_PATH_PREFIX");
    if (stripPrefix != nullptr) {
        myStripPrefixStr = stripPrefix;
    }

    const char* stackTraceMode = std::getenv("MUST_STACKTRACE");
    if (stackTraceMode != nullptr) {
        if (strcasecmp(stackTraceMode, "addr2line") == 0) {
            myStackTraceMode = MUST_STACKTRACE_ADDR2LINE;
        } else if (strcasecmp(stackTraceMode, "backward") == 0) {
#ifdef BUILD_BACKWARD
            myStackTraceMode = MUST_STACKTRACE_BACKWARD;
#else
            // init() cannot serve this mode without backward-cpp support; keep the default
            // instead of failing on the first intercepted call.
            must::cerr << "MUST_STACKTRACE=backward requested, but MUST was built without "
                          "backward-cpp support; stack traces are disabled."
                       << std::endl;
#endif
        }
    }
}

//=============================
// Destructor
//=============================
InitLocationId::~InitLocationId()
{
    // Give the sub modules acquired in the constructor back to GTI: first the parallel id
    // initializer, then the location id generator. Each pointer is cleared afterwards so
    // nothing is released twice.
    if (myPIdInit) {
        destroySubModuleInstance((I_Module*)myPIdInit);
    }
    myPIdInit = nullptr;

    if (myGenLId) {
        destroySubModuleInstance((I_Module*)myGenLId);
    }
    myGenLId = nullptr;
}

//=============================
// fillCodePtrs
//=============================
/**
 * Fills the code pointer related fields of @p locationInfo for the MPI call currently being
 * intercepted.
 *
 * callPtr / codePtr are the function address and return address reported by PnMPI;
 * fileName / fileBase describe (via dladdr) the loaded object containing the return address,
 * so that the code pointer can be resolved later by the location implementation module.
 * All four fields are reset first and stay empty/null when they cannot be determined.
 *
 * @param locationInfo location whose code pointer fields are filled.
 * @return false if PnMPI could not provide the code pointers; true otherwise, including the
 *         case where dladdr() fails (then only fileName / fileBase remain empty).
 */
bool InitLocationId::fillCodePtrs(LocationInfo& locationInfo)
{
    void* codeptr = nullptr;
    void* callptr = nullptr;
    void* baseptr = nullptr;

    locationInfo.callPtr = nullptr;
    locationInfo.codePtr = nullptr;
    locationInfo.fileName = "";
    locationInfo.fileBase = nullptr;

    // NOTE(review): baseptr is never used below, but a failure to query it still makes the
    // whole retrieval fail.
    if (PNMPI_Service_GetReturnAddress(&codeptr) != PNMPI_SUCCESS ||
        PNMPI_Service_GetFunctionAddress(&callptr) != PNMPI_SUCCESS ||
        PNMPI_Service_GetSelfBaseAddress(&baseptr) != PNMPI_SUCCESS) {
        must::cerr << "PnMPI failed to retrieve code pointers" << std::endl;
        return false;
    }

    locationInfo.callPtr = callptr;
    // The raw runtime return address is stored unchanged; together with fileBase below it
    // can presumably be turned into an object-relative address for addr2line later.
    locationInfo.codePtr = codeptr;

    Dl_info info;
    if (dladdr(codeptr, &info) == 0) {
        must::cerr << "Failed to call dladdr" << std::endl;
    } else {
        // Never construct the file name from a null pointer.
        if (info.dli_fname != nullptr)
            locationInfo.fileName = info.dli_fname;
        locationInfo.fileBase = info.dli_fbase;
    }
    return true;
}

#if defined(BUILD_BACKWARD)
//=============================
// startsWithAsciiNoCase
//=============================
/**
 * Returns true if @p str begins with @p prefix. ASCII letters compare case-insensitively,
 * all other characters must match exactly. Only 'A'-'Z' are folded, so the result does not
 * depend on the current C locale.
 */
static bool startsWithAsciiNoCase(const std::string& str, const char* prefix)
{
    auto toLowerAscii = [](char c) {
        return (c >= 'A' && c <= 'Z') ? static_cast<char>(c - 'A' + 'a') : c;
    };
    for (std::string::size_type i = 0; prefix[i] != '\0'; ++i) {
        if (i >= str.size() || toLowerAscii(str[i]) != toLowerAscii(prefix[i]))
            return false;
    }
    return true;
}
#endif

//=============================
// init
//=============================
/**
 * Determines the location id of the MPI call @p callName (GTI call id @p callId) and stores
 * it in @p pStorage.
 *
 * The stored value packs the location id into the lower 32 bits and the number of times
 * this call id has been seen so far (including this call) into the upper 32 bits. Whenever
 * a location is seen for the first time, it is announced via createHandleNewLocationCall.
 *
 * How locations of the same call id are told apart:
 *  - BUILD_BACKWARD: by the full call stack captured with backward-cpp;
 *  - otherwise, MUST_STACKTRACE_ADDR2LINE: by the return address reported by PnMPI;
 *  - otherwise, MUST_STACKTRACE_NONE: not at all, the call id itself is the location id.
 *
 * @return GTI_ANALYSIS_FAILURE if @p pStorage is null or the configured stack trace mode is
 *         not supported by this build, GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN InitLocationId::init(MustLocationId* pStorage, const char* callName, int callId)
{
    if (!pStorage)
        return GTI_ANALYSIS_FAILURE;

    MustLocationId id;     // result value
    uint32_t occCount = 0; // occurrence count of this callId

#if defined(BUILD_BACKWARD)
    // a) Capture the raw call stack; its addresses form the lookup key of this location
    LocationInfoImpl<uint64_t> thisLocation{};
    thisLocation.callName = callName;
    bool newStack = false;

    backward::StackTrace st;
    st.load_here(32);

    for (size_t i = 0; i < st.size(); ++i) {
        thisLocation.stack.push_back((uint64_t)st[i].addr);
    }

    auto pos = myKnownLocations.find(callId);
    if (pos == myKnownLocations.end()) {
        // First occurrence of this call id, hence also a new stack
        newStack = true;
        id = myGenLId->getNextLocationId();
        occCount = 1;
        std::map<LocationInfoImpl<uint64_t>, MustLocationId> temp;
        temp.insert(std::make_pair(thisLocation, id));
        myKnownLocations.insert(std::make_pair(callId, std::make_pair(temp, occCount)));
    } else {
        // Known call id: count this occurrence, then check whether the stack is known
        pos->second.second = pos->second.second + 1;
        occCount = pos->second.second;

        auto callIdPos = pos->second.first.find(thisLocation);
        if (callIdPos == pos->second.first.end()) {
            // A new stack
            id = myGenLId->getNextLocationId();
            pos->second.first.insert(std::make_pair(thisLocation, id));
            newStack = true;
        } else {
            // A known stack
            id = callIdPos->second;
        }
    }

    // b) Only for new stacks: build the full description (code pointers + resolved frames).
    //    Known stacks skip the PnMPI/dladdr queries and symbol resolution entirely.
    if (newStack) {
        LocationInfo thisFullLocation{};
        thisFullLocation.callName = callName;
        fillCodePtrs(thisFullLocation);

        tr.load_stacktrace(st);

        // The innermost frames belong to backward-cpp and MUST itself and are skipped.
        // NOTE(review): resolution starts at index 3; confirm the number of internal frames.
        for (size_t i = 3; i < st.size(); ++i) {
            backward::ResolvedTrace trace = tr.resolve(st[i]);
            const std::string& objectFunc = trace.object_function;

            // Ignore MPI (MPI_*), PnMPI (NQJ_*) and XMPI (XMPI_*) functions; the letters of
            // these prefixes are matched case-insensitively.
            if (startsWithAsciiNoCase(objectFunc, "mpi_") ||
                startsWithAsciiNoCase(objectFunc, "nqj_") ||
                startsWithAsciiNoCase(objectFunc, "xmpi_"))
                continue;

            // Stop at the libc startup code / program entry point
            if (objectFunc == "__libc_start_main" || objectFunc == "_start")
                break;

            // Ignore frames without a symbol name
            if (objectFunc.empty())
                continue;

            if (trace.source.function.empty()) {
                // Source information could not be read from the object
                thisFullLocation.stack.emplace_back(trace.object_filename, trace.object_function);
            } else {
                thisFullLocation.stack.emplace_back(
                    stripPathPrefix(myStripPrefixStr, trace.source.filename),
                    trace.source.function,
                    std::to_string(trace.source.line));
            }

            for (const auto& inlined_loc : trace.inliners) {
                thisFullLocation.stack.emplace_back(
                    inlined_loc.filename,
                    inlined_loc.function,
                    std::to_string(inlined_loc.line));
            }

            // Stop after the program's main function (MAIN__ presumably a Fortran main program)
            if (objectFunc == "main" || objectFunc == "MAIN__")
                break;
        }
        createHandleNewLocationCall(id, callName, thisFullLocation);
    }
#else
    // Search in the known locations
    auto pos = myKnownLocations.find(callId);

    if (pos == myKnownLocations.end()) {
        // First occurrence of this call id: always a new location
        occCount = 1;

        LocationInfo thisLocation;
        thisLocation.callName = callName;
        thisLocation.callPtr = nullptr;
        thisLocation.codePtr = nullptr;
        thisLocation.fileName = "\0";
        thisLocation.fileBase = nullptr;

        if (myStackTraceMode == MUST_STACKTRACE_NONE) {
            // Use the GTI call id as location id; the per-call-id map stays empty
            id = callId;
            std::map<void*, MustLocationId> temp;
            myKnownLocations.insert(std::make_pair(callId, std::make_pair(temp, occCount)));
            createHandleNewLocationCall(id, callName, thisLocation);
        } else if (myStackTraceMode == MUST_STACKTRACE_ADDR2LINE) {
            // One location id per distinct return address of this call id
            id = myGenLId->getNextLocationId();
            fillCodePtrs(thisLocation);

            std::map<void*, MustLocationId> temp;
            temp.insert(std::make_pair((void*)thisLocation.codePtr, id));

            myKnownLocations.insert(std::make_pair(callId, std::make_pair(temp, occCount)));
            createHandleNewLocationCall(id, callName, thisLocation);
        } else {
            // Mode not supported by this build (e.g. "backward" without BUILD_BACKWARD).
            // With assertions disabled, fail instead of storing an uninitialized id.
            assert(0);
            return GTI_ANALYSIS_FAILURE;
        }
    } else {
        // Increase occurrence count of this call id
        pos->second.second = pos->second.second + 1;
        occCount = pos->second.second;

        if (myStackTraceMode == MUST_STACKTRACE_NONE) {
            // Use the GTI call id as location id
            id = callId;
        } else if (myStackTraceMode == MUST_STACKTRACE_ADDR2LINE) {
            // Known call id: check whether this return address was seen before.
            // If PnMPI cannot provide it, codePtr stays null (the same key fillCodePtrs
            // produces on failure).
            void* codePtr = nullptr;
            PNMPI_Service_GetReturnAddress(&codePtr);

            auto elem = pos->second.first.find(codePtr);
            if (elem == pos->second.first.end()) {
                // Code pointer does not exist, create new location
                id = myGenLId->getNextLocationId();
                pos->second.first.insert(std::make_pair(codePtr, id));

                LocationInfo thisLocation{};
                thisLocation.callName = callName;

                fillCodePtrs(thisLocation);
                createHandleNewLocationCall(id, callName, thisLocation);
            } else {
                // Code pointer already exists, use existing location id
                id = elem->second;
            }
        } else {
            // Unsupported mode; see the first-occurrence branch above.
            assert(0);
            return GTI_ANALYSIS_FAILURE;
        }
    }
#endif

    // Store it
    // Lower 32 bit represent the location identifier, upper 32bit represent occurrence count
    *pStorage = (id & 0x00000000FFFFFFFF) | ((uint64_t)occCount << 32);

    return GTI_ANALYSIS_SUCCESS;
}

//=============================
// createHandleNewLocationCall
//=============================

/**
 * Announces a newly discovered location through the "handleNewLocation" wrapper function.
 *
 * @param id location id assigned to the new location.
 * @param callName name of the MPI call; passed on together with its own length including
 *        the terminating NUL, so pointer and length always describe the same string.
 * @param location code pointers, object file information and (with ENABLE_STACKTRACE) the
 *        call stack of the location.
 *
 * With ENABLE_STACKTRACE, every stack level contributes three NUL-terminated strings
 * (symbol name, file/module, line/offset) to one character buffer; InfoIndices records the
 * index of each string's terminating NUL. Characters are dropped once the buffer budget is
 * used up, and at most MUST_MAX_NUM_STACKLEVELS levels are serialized.
 */
void InitLocationId::createHandleNewLocationCall(
    MustLocationId id,
    const char* callName,
    LocationInfo& location)
{
#ifdef ENABLE_STACKTRACE
    char totalInfo[MUST_MAX_TOTAL_INFO_SIZE];
    int InfoIndices[MUST_MAX_NUM_STACKLEVELS * 3];
    // Character budget; 4 slots per level stay reserved, which covers the 3 NUL terminators
    // per level that are written even after the budget is exhausted.
    int maxtotalLen = MUST_MAX_TOTAL_INFO_SIZE - MUST_MAX_NUM_STACKLEVELS * 4;
    int totalLength = 0;
    int infoIndicesIndex = 0;

    for (auto iter = location.stack.begin();
         iter != location.stack.end() && infoIndicesIndex < MUST_MAX_NUM_STACKLEVELS * 3;
         ++iter) {
        const char* const pieces[3] = {
            iter->symName.c_str(),
            iter->fileModule.c_str(),
            iter->lineOffset.c_str()};

        for (const char* info : pieces) {
            for (int i = 0; info[i] != '\0' && totalLength < maxtotalLen; ++i)
                totalInfo[totalLength++] = info[i];
            totalInfo[totalLength] = '\0';
            // Record where this piece's terminator sits, then step past it
            InfoIndices[infoIndicesIndex++] = totalLength++;
        }
    }
#endif
    MustParallelId pId;
    myPIdInit->init(&pId);
    (*myNewLocFct)(
        pId,
        id,
        callName,
        std::strlen(callName) + 1,
        location.callPtr,
        location.codePtr,
        location.fileName.c_str(),
        location.fileName.length() + 1,
        location.fileBase
#ifdef ENABLE_STACKTRACE
        ,
        infoIndicesIndex / 3, /*Num stack levels*/
        totalLength,          /*stack infos total length*/
        infoIndicesIndex,     /*indicesLength*/
        InfoIndices,          /*infoIndices*/
        totalInfo             /*StackInfos*/
#endif
    );
}

//=============================
// initCodePtr
//=============================
/**
 * Derives a location id directly from a code pointer (return address).
 *
 * TODO: This is just a dummy implementation to convert the code pointer
 *       into an unique integer. In fact, this acts as an location ID.
 *       However, it should be enriched with additional data like line
 *       numbers in the future.
 *
 * @param pStorage receives the location id.
 * @param codeptr_ra code pointer whose integer value becomes the location id.
 * @return GTI_ANALYSIS_FAILURE if @p pStorage is null, GTI_ANALYSIS_SUCCESS otherwise.
 */
GTI_ANALYSIS_RETURN InitLocationId::initCodePtr(MustLocationId* pStorage, const void* codeptr_ra)
{
    assert(pStorage);
    // Do not dereference a null output pointer when assertions are compiled out
    // (init() handles a null pStorage the same way).
    if (!pStorage)
        return GTI_ANALYSIS_FAILURE;
    *pStorage = reinterpret_cast<MustLocationId>(codeptr_ra);
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Removes @p prefix from the beginning of @p path.
 *
 * @param prefix prefix to strip (e.g. the value of MUST_STRIP_PATH_PREFIX); an empty
 *        prefix disables stripping.
 * @param path path to shorten.
 * @return @p path without the leading @p prefix if it starts with it, else @p path as is.
 */
std::string InitLocationId::stripPathPrefix(const std::string& prefix, const std::string& path)
{
    // compare() looks only at the first prefix.size() characters, whereas find() would scan
    // the whole path whenever the prefix is not at its start.
    if (!prefix.empty() && path.compare(0, prefix.size(), prefix) == 0) {
        // prefix found at the start of path, return stripped path
        return path.substr(prefix.size());
    }
    // no prefix, return usual path
    return path;
}

/*EOF*/
