/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file MsgLoggerCommon.cpp
 *       @see must::MsgLoggerBase (declared in MsgLoggerCommon.hpp).
 *
 *  @date 04.01.2023
 *  @author Sebastian Grabowski
 */

#include "MsgLoggerCommon.hpp"
#include "PrefixedOstream.hpp"

#include <memory>
#include <sstream>

namespace must
{

namespace
{
/**
 * Converts a runtime code pointer into the address expected by symbolizers such as addr2line.
 *
 * A file base of 0x400000 is the traditional link address of non-PIE x86-64 executables; for
 * those the runtime code pointer already equals the link-time address and is returned as is.
 * For every other base (PIE executables, shared libraries) the load offset is subtracted.
 *
 * @param codePtr  runtime address of the code location.
 * @param fileBase runtime base address of the object file containing @p codePtr.
 * @return the address relative to the object file.
 */
auto relativeCodeAddress(intptr_t codePtr, intptr_t fileBase) -> intptr_t
{
    constexpr intptr_t nonPieBaseAddress = 0x400000;
    return fileBase == nonPieBaseAddress ? codePtr : codePtr - fileBase;
}

#if defined(CMAKE_ADDR2LINE)
/**
 * Quotes @p arg as a single POSIX shell word so that file names containing spaces or shell
 * metacharacters are passed to the command verbatim.
 */
auto shellQuote(std::string const& arg) -> std::string
{
    std::string quoted = "'";
    for (char const c : arg) {
        if (c == '\'') {
            // Close the quote, emit an escaped quote, reopen the quote.
            quoted += "'\\''";
        } else {
            quoted += c;
        }
    }
    quoted += '\'';
    return quoted;
}
#endif
} // namespace

/**
 * Renders the location @p lId of parallel id @p pId according to the configured stack trace mode.
 *
 * - MUST_STACKTRACE_NONE: "<file>+0x<16-digit relative address>\n" when a code pointer is known,
 *   otherwise the call name followed by its occurrence count.
 * - MUST_STACKTRACE_ADDR2LINE: the output of addr2line for the code pointer; falls back to the
 *   NONE rendering when there is nothing to resolve or addr2line produced no output.
 * - MUST_STACKTRACE_BACKWARD: one "#<depth>  sym@module[:line]" entry per frame, separated by
 *   @p newline.
 *
 * Modes whose support is not compiled in are reported on must::cerr and rendered like NONE.
 *
 * @param pId parallel id the location belongs to.
 * @param lId location id to render.
 * @param newline separator placed between frames in backward mode.
 * @return the rendered location.
 */
auto MsgLoggerBase::format_stacktrace(MustParallelId pId, MustLocationId lId, char const* newline)
    -> std::string
{
    auto const& lInfo = myLIdModule->getInfoForId(pId, lId);
    std::stringstream out;

    // Unresolved rendering shared by NONE mode and all fallbacks.
    auto const writePlain = [&]() {
        if (lInfo.codePtr != nullptr) {
            out << lInfo.fileName << "+0x" << std::hex << std::setfill('0') << std::setw(16)
                << relativeCodeAddress((intptr_t)lInfo.codePtr, (intptr_t)lInfo.fileBase)
                << '\n';
        } else {
            out << lInfo.callName;
            printOccurenceCount(out, lId);
        }
    };

    switch (myStackTraceMode) {
    case MUST_STACKTRACE_NONE:
        writePlain();
        break;
#if defined(CMAKE_ADDR2LINE)
    case MUST_STACKTRACE_ADDR2LINE: {
        std::string resolved;
        if (lInfo.codePtr != nullptr && !lInfo.fileName.empty()) {
            // Subtract 1 because codePtr is a return address; otherwise the location of the
            // instruction following the call would be reported.
            intptr_t const address =
                relativeCodeAddress((intptr_t)lInfo.codePtr, (intptr_t)lInfo.fileBase) - 1;
            std::ostringstream cmd;
            cmd << CMAKE_ADDR2LINE << " -e " << shellQuote(lInfo.fileName) << " -f -p -C -i "
                << reinterpret_cast<void*>(address);
            resolved = this->execCommand(cmd.str());
        }
        if (resolved.empty()) {
            // Nothing to resolve or addr2line failed: keep at least the unresolved location.
            writePlain();
        } else {
            out << resolved;
        }
        break;
    }
#endif
#if defined(USE_BACKWARD)
    case MUST_STACKTRACE_BACKWARD: {
        int depth = 0;
        for (auto const& frame : lInfo.stack) {
            if (depth != 0) {
                out << newline;
            }
            out << "#" << depth << "  " << frame.symName << "@" << frame.fileModule
                << (frame.lineOffset.empty() ? std::string{} : ":" + frame.lineOffset);
            ++depth;
        }
        break;
    }
#endif
    default:
        must::cerr << "Internal error: Unsupported stacktrace mode selected." << std::endl;
        writePlain();
        break;
    }

    return out.str();
}

/**
 * Initializes the logger from the environment.
 *
 * - MUST_RETURNCODE_FILE: path that error codes are appended to (unset disables this).
 * - MUST_STACKTRACE: "addr2line" or "backward" (case-insensitive) selects how locations are
 *   rendered. A mode is only accepted when its support was compiled in; otherwise the default
 *   MUST_STACKTRACE_NONE stays active. Without this check every later stack trace would end up
 *   in format_stacktrace's "unsupported mode" branch.
 */
MsgLoggerBase::MsgLoggerBase()
    : myPIdModule(nullptr), myLIdModule(nullptr), myStackTraceMode(MUST_STACKTRACE_NONE)
{
    myErrorCodeFile = std::getenv("MUST_RETURNCODE_FILE");

    char const* const stackTraceMode = std::getenv("MUST_STACKTRACE");
    if (stackTraceMode != nullptr) {
#if defined(CMAKE_ADDR2LINE)
        if (strcasecmp(stackTraceMode, "addr2line") == 0) {
            myStackTraceMode = MUST_STACKTRACE_ADDR2LINE;
        }
#endif
#if defined(USE_BACKWARD)
        if (strcasecmp(stackTraceMode, "backward") == 0) {
            myStackTraceMode = MUST_STACKTRACE_BACKWARD;
        }
#endif
    }
}

/**
 * Appends the current error code as one line to the file named by MUST_RETURNCODE_FILE.
 *
 * @return true if the code was written successfully; false if no file is configured or
 *         opening/writing/closing the file failed.
 */
bool MsgLoggerBase::emitErrorCodeToFile()
{
    if (myErrorCodeFile == nullptr)
        return false;

    // Append mode: every call adds one line, so repeated emissions do not overwrite each other.
    std::ofstream fp(myErrorCodeFile, std::ofstream::out | std::ofstream::app);
    fp << std::to_string(myErrorCode) << '\n';
    // close() flushes; any failure while opening, writing or flushing leaves failbit set.
    fp.close();
    return !fp.fail();
}

/**
 * Writes the current error code to the return-code file one final time (no-op when
 * MUST_RETURNCODE_FILE is unset).
 */
MsgLoggerBase::~MsgLoggerBase()
{
    this->emitErrorCodeToFile();
}

/**
 * Records @p msgType as the error code if it exceeds the currently stored one, and emits the
 * new value to the return-code file.
 *
 * @param msgType candidate error code; values not larger than the stored code are ignored.
 */
void MsgLoggerBase::rememberErrorcode(int msgType)
{
    // Only a strictly larger code replaces the stored one.
    if (msgType <= myErrorCode)
        return;

    myErrorCode = msgType;
    emitErrorCodeToFile();
}

/**
 * Runs @p cmd through the shell (popen) and collects everything it writes to stdout.
 *
 * @param cmd shell command line to execute.
 * @return the command's standard output, or an empty string if the pipe could not be opened.
 */
std::string MsgLoggerBase::execCommand(const std::string& cmd) const
{
    std::string output;

    // pclose() runs when the guard leaves scope, even if appending to output throws.
    std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd.c_str(), "r"), &pclose);
    if (!pipe)
        return output;

    // Loop on fgets() itself: after a read error feof() stays false, so a feof()-driven loop
    // would never terminate.
    char buffer[256];
    while (fgets(buffer, sizeof buffer, pipe.get()) != nullptr)
        output.append(buffer);

    return output;
}

namespace
{
/**
 * Returns the English ordinal suffix for @p n: "st", "nd", "rd" or "th".
 * Handles the irregular teens, e.g. 11th/12th/13th and 111th, versus 21st/22nd/23rd.
 */
auto ordinalSuffix(unsigned long long n) -> char const*
{
    unsigned long long const lastTwo = n % 100;
    if (lastTwo >= 11 && lastTwo <= 13)
        return "th";

    switch (n % 10) {
    case 1:
        return "st";
    case 2:
        return "nd";
    case 3:
        return "rd";
    default:
        return "th";
    }
}
} // namespace

/**
 * Appends " (<n><suffix> occurrence)" to @p out, where n is the occurrence count of location
 * @p lId, e.g. " (1st occurrence)" or " (22nd occurrence)".
 *
 * @param out stream to write to.
 * @param lId location whose occurrence count is printed.
 */
void MsgLoggerBase::printOccurenceCount(std::ostream& out, MustLocationId lId)
{
    // Query the count once instead of once per comparison.
    auto const count = myLIdModule->getOccurenceCount(lId);
    out << " (" << count << ordinalSuffix(static_cast<unsigned long long>(count))
        << " occurrence)";
}

} // namespace must
