/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

/**
 * @file TSanMessages.cpp
 *       @see MUST::TSanMessages.
 *
 *  @date 23.11.2017
 *  @author Joachim Protze, Felix Dommes
 */

#include "TSanMessages.h"

#include <stdio.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <list>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "BaseIds.h"
#include "GtiMacros.h"
#include "GtiTypes.h"
#include "I_Module.h"
#include "MustEnums.h"
#include "PrefixedOstream.hpp"

using namespace std;
using namespace gti;
using namespace must;

// GTI module boilerplate. These macros presumably come from GtiMacros.h and provide
// TSanMessages::getInstance()/freeInstance() (used by TsanOnReport() below) and the
// PnMPI registration point — see GtiMacros.h to confirm.
mGET_INSTANCE_FUNCTION(TSanMessages)
mFREE_INSTANCE_FUNCTION(TSanMessages)
mPNMPI_REGISTRATIONPOINT_FUNCTION(TSanMessages)

// Set to true by TSanMessages::fini(); afterwards TsanOnReport() no longer forwards
// race reports to the module.
static bool finalized = false;
// Set in the TSanMessages constructor (from MUST_DISABLE_TSAN_MESSAGES); while true,
// TsanOnReport() returns immediately without looking at the report.
static bool disabled = false;

namespace
{

/**
 * Appends the code location of one memory access to a race message.
 *
 * Writes "<function>@<refNum>" when the access carries a symbolized stack and
 * "[failed to restore the stack]" when it does not. A top frame without a symbol name is
 * written as "[unknown function]", because streaming a null `char*` into an ostream is
 * undefined behavior.
 *
 * @param msg the message being built
 * @param mop the memory access (a TSan ReportMop)
 * @param refNum the number written after '@'
 */
template <typename Mop>
void appendAccessSite(std::ostream& msg, const Mop* const mop, std::ptrdiff_t refNum)
{
    if (mop->stack == nullptr || mop->stack->frames == nullptr) {
        msg << "[failed to restore the stack]";
        return;
    }
    const char* const function = mop->stack->frames->info.function;
    msg << (function != nullptr ? function : "[unknown function]") << "@" << refNum;
}

/**
 * Formats a message in MUST style from the Thread Sanitizer's report.
 * @param report the report provided by Thread Sanitizer; must contain at least one access
 * @param fiberNum index into report->threads of the fiber that initiated the asynchronous
 *                 execution, or a negative value if no fiber is involved
 * @return the formatted message string
 */
template <__tsan::TSanVersion T>
auto format(const __tsan::ReportDescT<T>* const report, int fiberNum) -> std::string
{
    std::stringstream msg;
    const auto* firstMop = *report->mops.begin_;
    msg << "Detected data race between a " << (firstMop->write ? "write" : "read")
        << " of size " << firstMop->size << " at ";
    appendAccessSite(msg, firstMop, 1);

    for (auto** it = report->mops.begin_ + 1; it != report->mops.end_; ++it) {
        const auto* pMop = *it;
        msg << (it == report->mops.end_ - 1 ? " and " : ", ");
        msg << "a previous " << (pMop->write ? "write" : "read") << " of size " << pMop->size
            << " at ";
        appendAccessSite(msg, pMop, std::distance(report->mops.begin_, it) + 1);
    }
    msg << ".";

    if (fiberNum >= 0) {
        // The thread name may be null; it must not be streamed as a char* then.
        const char* const fiberName = report->threads.begin_[fiberNum]->name;
        msg << " Asynchronous execution was initiated as "
            << (fiberName != nullptr ? fiberName : "[unnamed fiber]") << " at reference 2.";
    }
    return msg.str();
}
} // namespace

namespace __tsan
{
/**
 * Called by OnReport() in OnReportLoader.cpp when Thread Sanitizer emits a
 * report.
 *
 * It forwards the report the the TSanMessages instance and prevents thread
 * sanitizer to print its own message to the console.
 */
extern "C" bool TsanOnReport(const __tsan::ReportDesc* rep, bool _suppressed, int llvm_version)
{
    if (disabled)
        return _suppressed;
    auto* pre14Desc = (__tsan::ReportDescT<__tsan::TSanVersion::pre14>*)(rep);
    auto* post14Desc = (__tsan::ReportDescT<__tsan::TSanVersion::post14>*)(rep);
    // we only support data races
    switch (pre14Desc->typ) {
    case __tsan::ReportTypeRace:
        break;
    default:
        return _suppressed;
    }

    INFO(MUST_TSAN, "Received TSan report.");

    // forward data-race report
    if (!finalized) {
        TSanMessages* tsanMessages = TSanMessages::getInstance("");
        if (tsanMessages == nullptr) {
            must::cerr << "TSanMessages instance not found, cannot report TSan message." << endl;
            return _suppressed;
        }
        if (llvm_version < 14)
            tsanMessages->tsanReport(pre14Desc);
        else
            tsanMessages->tsanReport(post14Desc);
        // immediately free instance to ensure that destructor
        // will be called later on when all other modules free
        // TSanMessages
        TSanMessages::freeInstance(tsanMessages);
    }

    // returning true prevents tsan from printing its own report to stdout
    return _suppressed;
    // return false;
}
} /* namespace __tsan */

//=============================
// Constructor
//=============================
TSanMessages::TSanMessages(const char* instanceName)
    : gti::ModuleBase<TSanMessages, I_TSanMessages>(instanceName)
{
    // Required sub modules by index: [0] I_CreateMessage, [1] I_GenerateLocationId,
    // [2] I_ParallelIdAnalysis, [3] I_TargetRaceVerifier.
    constexpr std::size_t kNumSubModules = 4;

    // create sub modules
    std::vector<I_Module*> subModInstances = createSubModuleInstances();

    if (subModInstances.size() < kNumSubModules) {
        must::cerr << "Module has not enough sub modules, check its analysis specification! ("
                   << __FILE__ << "@" << __LINE__ << ")" << endl;
        assert(0);
        // assert() is compiled out with NDEBUG; the indexing below would then read past the
        // end of subModInstances. Release what was created, leave the module without sub
        // modules and stop forwarding TSan reports, which would otherwise dereference them.
        for (I_Module* subModule : subModInstances)
            destroySubModuleInstance(subModule);
        myLogger = nullptr;
        myGenLId = nullptr;
        myPIdMod = nullptr;
        myTargetRaceVerifierMod = nullptr;
        disabled = true;
        return;
    }

    // release sub modules beyond the ones this module uses
    for (std::size_t i = kNumSubModules; i < subModInstances.size(); ++i)
        destroySubModuleInstance(subModInstances[i]);

    myLogger = (I_CreateMessage*)subModInstances[0];
    myGenLId = (I_GenerateLocationId*)subModInstances[1];
    myPIdMod = (I_ParallelIdAnalysis*)subModInstances[2];
    myTargetRaceVerifierMod = (I_TargetRaceVerifier*)subModInstances[3];

    // get handleNewLocation function
    getWrapperFunction("handleNewLocation", (GTI_Fct_t*)&myNewLocFunc);

    // Environment switches: a non-zero integer enables the respective option.
    if (auto env = std::getenv("MUST_DISABLE_TSAN_MESSAGES"))
        disabled = static_cast<bool>(std::atoi(env));

    if (auto env = std::getenv("MUST_DELAY_RACE_REPORTS"))
        myDelayedReport = static_cast<bool>(std::atoi(env));
}

//=============================
// Destructor
//=============================
TSanMessages::~TSanMessages()
{
    // Hands a sub module back to GTI if this module still holds it.
    const auto releaseSubModule = [this](I_Module* subModule) {
        if (subModule != nullptr)
            destroySubModuleInstance(subModule);
    };

    releaseSubModule((I_Module*)myLogger);
    myLogger = nullptr;

    releaseSubModule((I_Module*)myGenLId);
    myGenLId = nullptr;

    releaseSubModule((I_Module*)myPIdMod);
    myPIdMod = nullptr;

    releaseSubModule((I_Module*)myTargetRaceVerifierMod);
    myTargetRaceVerifierMod = nullptr;
}

/**
 * Prints a one-line summary of a TSan report (type, tag and the sizes of its vectors),
 * followed by one line per involved thread, to stdout.
 * @param report the report provided by Thread Sanitizer
 */
template <__tsan::TSanVersion T>
void debugPrint(const __tsan::ReportDescT<T>* const report)
{
    /*  Layout of ReportDescT<T> for reference:
        ReportType typ;
        uptr tag;
        Vector<ReportStack<T>*> stacks;
        Vector<ReportMop<T>*> mops;
        Vector<ReportLocation<T>*> locs;
        Vector<ReportMutex<T>*> mutexes;
        Vector<ReportThread<T>*> threads;
        Vector<Tid> unique_tids;
        ReportStack<T>* sleep;
        int count;
        int signum;*/
    // Each argument is converted explicitly to the type its conversion specifier expects,
    // independent of the integer typedefs used by the TSan runtime.
    printf(
        "typ = %i, tag = %li, #stacks = %li, #mops = %li, #locs = %li, #mutexes = %li, "
        "#threads = %li, #unique_tids = %li\n",
        static_cast<int>(report->typ),
        static_cast<long>(report->tag),
        static_cast<long>(report->stacks.Size()),
        static_cast<long>(report->mops.Size()),
        static_cast<long>(report->locs.Size()),
        static_cast<long>(report->mutexes.Size()),
        static_cast<long>(report->threads.Size()),
        static_cast<long>(report->unique_tids.Size()));
    for (auto** it = report->threads.begin_; it != report->threads.end_; ++it) {
        const auto* pThread = *it;
        // The thread name may be null; passing a null pointer for %s is undefined behavior.
        printf(
            "Thread %i, %s\n",
            static_cast<int>(pThread->id),
            pThread->name != nullptr ? pThread->name : "(null)");
    }
}

/**
 * Builds a stack trace string in the format expected by the MUST message mechanism.
 *
 * All components are concatenated into one string without separators or terminators:
 * for every frame that has a symbol name, its symbol name, its file name (or module name
 * if no file is known) and its line number (or "0x"-prefixed hexadecimal module offset
 * if the line is 0). Frames without a symbol name are skipped entirely.
 * For every component the offset of its last character in `locString` is appended to
 * `indices`, i.e. three entries per emitted frame. If neither file nor module is known,
 * the file and line components are empty and their entries repeat the previous offset.
 *
 * @param stack the stack trace (may be nullptr, yielding an empty result)
 * @param locString[out] the whole concatenated string
 * @param indices[out] offsets of the last character of each component in locString
 * @param stack_depth[out] the number of frames emitted, i.e. the stack trace's height
 */
template <__tsan::TSanVersion T>
void build_stacktrace(
    const __tsan::SymbolizedStack<T>* const stack,
    std::string& locString,
    std::vector<int>& indices,
    int& stack_depth)
{
    locString.clear();
    indices.clear();
    stack_depth = 0;

    // Records where the most recently appended component ends.
    const auto markEnd = [&locString, &indices]() {
        indices.emplace_back(locString.length() - 1);
    };

    for (const auto* frame = stack; frame != nullptr; frame = frame->next) {
        const auto& info = frame->info;
        if (info.function == nullptr)
            continue;

        locString += info.function;
        markEnd();
        ++stack_depth;

        const char* source = (info.file != nullptr) ? info.file : info.module;
        if (source == nullptr) {
            // empty file and line components
            markEnd();
            markEnd();
            continue;
        }
        locString += source;
        markEnd();

        if (info.line != 0) {
            locString += std::to_string(info.line);
        } else {
            std::ostringstream offset;
            offset << "0x" << std::hex << info.module_offset;
            locString += offset.str();
        }
        markEnd();
    }
}

/**
 * Creates a Location from a TSan SymbolizedStack and registers it with MUST.
 *
 * A fresh location id is always drawn from the location-id generator. The location is
 * passed to handleNewLocation immediately, or buffered until fini() when delayed
 * reporting (MUST_DELAY_RACE_REPORTS) is enabled.
 *
 * @param pId[in] the process Id
 * @param stack[in] the stack trace; may be nullptr, which yields an empty location
 * @param lId[out] a new Location Id for the generated location
 */
template <__tsan::TSanVersion T>
void TSanMessages::createLocation(
    MustParallelId pId,
    const __tsan::SymbolizedStack<T>* const stack,
    MustLocationId& lId)
{
    lId = myGenLId->getNextLocationId();

    // The stack information is declared unconditionally because the delayed-report path
    // stores it regardless of ENABLE_STACKTRACE; it simply stays empty without it.
    std::string stack_string{};
    std::vector<int> stack_indices{};
    int levels = 0;
#ifdef ENABLE_STACKTRACE
    build_stacktrace(stack, stack_string, stack_indices, levels);
#endif

    LocationInfo locInfo{};
    if (stack != nullptr && stack->info.function != nullptr) {
        locInfo.callName = stack->info.function;
        locInfo.codePtr = (void*)(stack->info.module_offset + 1);
        // assigning a null char* to a std::string is undefined behavior
        if (stack->info.module != nullptr)
            locInfo.fileName = stack->info.module;
    }

    if (myDelayedReport) {
        delayedLocations.emplace_back(DelayedLocationInfo{
            locInfo,
            levels,
            std::move(stack_indices),
            std::move(stack_string),
            lId,
            pId});
    } else {
        (*myNewLocFunc)(
            pId,
            lId,
            locInfo.callName.c_str(),
            locInfo.callName.length() + 1,
            NULL,
            locInfo.codePtr,
            locInfo.fileName.c_str(),
            locInfo.fileName.length() + 1,
            NULL
#ifdef ENABLE_STACKTRACE
            ,
            levels,
            stack_string.length() + 1, // the length of all concatenated stack strings
            stack_indices.size(),      // the number of indices (this is claimed to be convenient)
            stack_indices.data(),
            stack_string.c_str() // should only be read access
#endif
        );
    }
}

/**
 * Formats a race message in MUST style from the Thread Sanitizer's report together with a
 * list of references. The references are used to extract further information, e.g., the
 * remote rank in case of a remote data race.
 *
 * refList is expected to hold one entry per access of the report that has a stack, in
 * report order (as built by tsanReportRMA()); accesses without a stack do not consume an
 * entry and are printed without a rank.
 *
 * @param report the report provided by Thread Sanitizer
 * @param refList list of (pid, lid) references to extract further information (e.g. remote ranks)
 * @return the formatted message string
 */
template <__tsan::TSanVersion T>
auto TSanMessages::formatRMA(
    const __tsan::ReportDescT<T>* const report,
    std::list<std::pair<MustParallelId, MustLocationId>>& refList) -> std::string
{
    std::stringstream msg;

    // extract list of ranks involved (one per entry of refList)
    std::vector<int> rankList;
    std::transform(
        refList.begin(),
        refList.end(),
        std::back_inserter(rankList),
        [this](std::pair<MustParallelId, MustLocationId> ref) {
            return myPIdMod->getInfoForId(ref.first).rank;
        });

    // obtain own pId
    MustParallelId localpId = 0;
    getNodeInLayerId(&localpId);
    const int localRank = myPIdMod->getInfoForId(localpId).rank;

    // If all ranks are the same: Local buffer race, otherwise remote data race
    const bool isLocalRace = std::all_of(rankList.begin(), rankList.end(), [&](int rank) {
        return rank == rankList.front();
    });
    msg << (isLocalRace ? "Local buffer data race at rank " : "Remote data race at rank ")
        << localRank << " between a ";

    // Ranks are consumed only by accesses that have a stack, mirroring how refList was built;
    // the iterator is never dereferenced past the end of rankList.
    auto rankIt = rankList.cbegin();
    for (auto** it = report->mops.begin_; it != report->mops.end_; ++it) {
        const auto* pMop = *it;
        if (it != report->mops.begin_) {
            msg << (it == report->mops.end_ - 1 ? " and " : ", ") << "a previous ";
        }
        msg << (pMop->write ? "write" : "read") << " of size " << pMop->size << " at ";
        if (pMop->stack == nullptr || pMop->stack->frames == nullptr) {
            msg << "[failed to restore the stack]";
        } else {
            const char* const function = pMop->stack->frames->info.function;
            msg << (function != nullptr ? function : "[unknown function]") << "@"
                << std::distance(report->mops.begin_, it) + 1;
        }
        if (pMop->stack != nullptr && rankIt != rankList.cend()) {
            if (!isLocalRace)
                msg << " from rank " << *rankIt;
            ++rankIt;
        }
    }
    msg << ".";
    return msg.str();
}

/**
 * TSan report handler for RMA races.
 *
 * Should only be called if an RMA fiber is involved (fiber name contains "RMAFiber").
 * Every access with a stack contributes one reference, in report order:
 *  - If its stack encodes an RMA operation (frame #1 is the delimiter address), the
 *    encoded (pId, lId) is referenced directly. Unless the encoded rmaId marks a local
 *    buffer annotation, the operation is also kept for verification by the
 *    TargetRaceVerifier module.
 *  - Otherwise a new location is built from the TSan stack trace.
 * If at least two RMA operations were collected and the verifier reports no race between
 * the first two, nothing is reported.
 *
 * @param report the report provided by Thread Sanitizer
 * @return GTI_ANALYSIS_SUCCESS
 */
template <__tsan::TSanVersion T>
auto TSanMessages::tsanReportRMA(const __tsan::ReportDescT<T>* const report) -> GTI_ANALYSIS_RETURN
{
    // Address of frame #1 marking a stack that encodes an RMA operation. Layout of such a
    // stack: #0 any frame, #1 delimiter, #2 rmaId, #3 pId, #4 lId.
    constexpr auto kEncodedStackDelimiter = 0x0FFFFFFFFFFFFFFF;
    // rmaId used by local buffer annotations (nothing to verify remotely).
    constexpr auto kLocalBufferRmaId = 0x000000000FFFFFFF;

    // We can derive *our* pId by using the same approach as in InitParallelIdHybrid.cpp
    // without requiring InitParallelId to be called itself (otherwise the TSanMessages
    // module cannot run on tool thread layer).
    MustParallelId pId = 0;
    getNodeInLayerId(&pId);

    // References of the message, one per access with a stack, in report order.
    auto refList = std::list<std::pair<MustParallelId, MustLocationId>>{};
    // Conflicting RMA operations and the 1-based position of their access in refList.
    auto rmaConflictPair = std::vector<std::pair<MustParallelId, MustRMAId>>{};
    auto conflictRefNums = std::vector<std::size_t>{};

    // New locations built from TSan stack traces; registered only after race verification.
    // The stack vectors are kept unconditionally because the delayed-report path stores
    // them regardless of ENABLE_STACKTRACE (they stay empty without it).
    auto locations = std::vector<LocationInfo>{};
    auto lIdList = std::vector<MustLocationId>{};
    auto locStrings = std::vector<std::string>{};
    auto indices = std::vector<std::vector<int>>{};
    auto stackLevels = std::vector<int>{};

    for (auto** it = report->mops.begin_; it != report->mops.end_; ++it) {
        const auto* stack = (*it)->stack;
        if (stack == nullptr)
            continue;

        // Frames that may encode an RMA operation; a missing frame yields nullptr.
        const auto* topFrame = stack->frames;
        const auto* delimFrame = (topFrame != nullptr) ? topFrame->next : nullptr;
        const auto* rmaIdFrame = (delimFrame != nullptr) ? delimFrame->next : nullptr;
        const auto* pIdFrame = (rmaIdFrame != nullptr) ? rmaIdFrame->next : nullptr;
        const auto* lIdFrame = (pIdFrame != nullptr) ? pIdFrame->next : nullptr;

        if (lIdFrame != nullptr && delimFrame->info.address == kEncodedStackDelimiter) {
            const MustRMAId rmaId = rmaIdFrame->info.address;
            const MustParallelId refPId = pIdFrame->info.address;
            const MustLocationId refLId = lIdFrame->info.address;
            refList.emplace_back(refPId, refLId);
            if (rmaId != kLocalBufferRmaId) {
                // found conflicting remote RMA operation, this has to be further
                // investigated by the TargetRaceVerifier module
                rmaConflictPair.emplace_back(refPId, rmaId);
                conflictRefNums.push_back(refList.size());
            }
            continue;
        }

        // otherwise, do the usual processing by parsing the TSan stacktrace
        LocationInfo locInfo{};
        if (topFrame != nullptr) {
            const char* const function = topFrame->info.function;
            locInfo.callName = (function != nullptr) ? function : "";
            locInfo.codePtr = (void*)(topFrame->info.module_offset);
            // assigning a null char* to a std::string is undefined behavior
            if (topFrame->info.module != nullptr)
                locInfo.fileName = topFrame->info.module;
        }

        std::string stackString{};
        std::vector<int> stackIndices{};
        int levels = 0;
#ifdef ENABLE_STACKTRACE
        build_stacktrace(topFrame, stackString, stackIndices, levels);
#endif
        locations.push_back(std::move(locInfo));
        locStrings.push_back(std::move(stackString));
        indices.push_back(std::move(stackIndices));
        stackLevels.push_back(levels);

        const MustLocationId locationId = myGenLId->getNextLocationId();
        lIdList.push_back(locationId);
        refList.emplace_back(pId, locationId);
    }

    if (rmaConflictPair.size() >= 2 &&
        !myTargetRaceVerifierMod->isRace(
            rmaConflictPair[0].first,
            rmaConflictPair[0].second,
            rmaConflictPair[1].first,
            rmaConflictPair[1].second)) {
        // no race, do not report anything
        return GTI_ANALYSIS_SUCCESS;
    }

    // register the newly created locations
    for (std::size_t i = 0; i < locations.size(); ++i) {
        if (myDelayedReport) {
            delayedLocations.emplace_back(DelayedLocationInfo{
                locations[i],
                stackLevels[i],
                indices[i],
                locStrings[i],
                lIdList[i],
                pId});
        } else {
            (*myNewLocFunc)(
                pId,
                lIdList[i],
                locations[i].callName.c_str(),
                locations[i].callName.length() + 1,
                NULL,
                locations[i].codePtr,
                locations[i].fileName.c_str(),
                locations[i].fileName.length() + 1,
                NULL
#ifdef ENABLE_STACKTRACE
                ,
                stackLevels[i],
                locStrings[i].length() + 1, // the length of all concatenated stack strings
                indices[i].size(), // the number of indices (this is claimed to be convenient)
                indices[i].data(),
                locStrings[i].c_str() // should only be read access
#endif
            );
        }
    }

    // format and create some nice message
    std::stringstream msg;
    msg << formatRMA(report, refList);

    // Describe the concurrent region of every conflicting RMA operation. Its start and end
    // locations are appended to refList, so their reference numbers are the next positions.
    for (std::size_t i = 0; i < rmaConflictPair.size(); ++i) {
        RMAOpHistoryData opData;
        myTargetRaceVerifierMod->getHistoryData(
            rmaConflictPair[i].first,
            rmaConflictPair[i].second,
            opData);
        msg << " Concurrent region of reference " << conflictRefNums[i]
            << " started at reference " << refList.size() + 1 << " and ended at reference "
            << refList.size() + 2 << ". " << endl;
        refList.emplace_back(pId, opData.lIdStart);
        refList.emplace_back(pId, opData.lIdEnd);
    }

    if (myDelayedReport) {
        delayedMessages.emplace_back(DelayedMesssage{
            MUST_WARNING_DATARACE,
            pId,
            0,
            MustErrorMessage,
            msg.str(),
            std::move(refList)});
    } else {
        myLogger
            ->createMessage(MUST_WARNING_DATARACE, pId, 0, MustErrorMessage, msg.str(), refList);
    }
    return gti::GTI_ANALYSIS_SUCCESS;
}

//=============================
// tsanReport
//=============================
/**
 * Handles a TSan data-race report.
 *
 * Reports involving a fiber whose name contains "RMAFiber" are delegated to
 * tsanReportRMA(). Otherwise a location is created for every access that has a stack and
 * for the stack of every involved fiber (without its top frame). The first location
 * becomes the message location, the remaining ones its references.
 *
 * @param report the report provided by Thread Sanitizer
 * @return GTI_ANALYSIS_SUCCESS, or the result of tsanReportRMA() for RMA races
 */
template <__tsan::TSanVersion T>
auto TSanMessages::tsanReport(const __tsan::ReportDescT<T>* const report) -> GTI_ANALYSIS_RETURN
{
    for (auto** it = report->threads.begin_; it != report->threads.end_; ++it) {
        const __tsan::ReportThread<T>* pThread = *it;
        // A fiber's name may be null; constructing a std::string from nullptr is undefined.
        if (pThread->thread_type == __tsan::ThreadType::Fiber && pThread->name != nullptr &&
            std::string(pThread->name).find("RMAFiber") != std::string::npos) {
            // RMA Race, special handling required
            return tsanReportRMA(report);
        }
    }

    // NOTE(review): prints a summary of every race report to stdout; looks like leftover
    // debug output — confirm it is intended for production runs.
    debugPrint(report);

    // obtain IDs and register the locations
    MustParallelId pId = 0;
    getNodeInLayerId(&pId);
    std::list<std::pair<MustParallelId, MustLocationId>> refList{};

    // one location per memory access whose stack is available
    for (auto** it = report->mops.begin_; it != report->mops.end_; ++it) {
        const auto* pMop = *it;
        if (pMop->stack == nullptr)
            continue;
        MustLocationId lId = 0; // assigned by createLocation()
        createLocation(pId, pMop->stack->frames, lId);
        refList.emplace_back(pId, lId);
    }

    // one location per fiber; fiberNum remembers the index of the last such fiber
    int fiberNum = -1;
    int threadIdx = 0;
    for (auto** it = report->threads.begin_; it != report->threads.end_; ++it, ++threadIdx) {
        const auto* pThread = *it;
        if (pThread->thread_type == __tsan::ThreadType::Fiber && pThread->stack != nullptr &&
            pThread->stack->frames != nullptr && pThread->stack->frames->next != nullptr) {
            // createLocation() draws the location id itself
            MustLocationId lId = 0;
            createLocation(pId, pThread->stack->frames->next, lId);
            refList.emplace_back(pId, lId);
            fiberNum = threadIdx;
        }
    }

    // format and create some nice message
    std::string msg = format(report, fiberNum);

    // The first reference becomes the message location. If no access had a stack, fall back
    // to our own pId without a location instead of reading front() of an empty list.
    std::pair<MustParallelId, MustLocationId> topRef{pId, 0};
    if (!refList.empty()) {
        topRef = refList.front();
        refList.pop_front();
    }

    if (myDelayedReport) {
        delayedMessages.emplace_back(DelayedMesssage{
            MUST_WARNING_DATARACE,
            topRef.first,
            topRef.second,
            MustErrorMessage,
            std::move(msg),
            std::move(refList)});
    } else {
        myLogger->createMessage(
            MUST_WARNING_DATARACE,
            topRef.first,
            topRef.second,
            MustErrorMessage,
            msg,
            refList);
    }
    return GTI_ANALYSIS_SUCCESS;
}

/**
 * Finalizes the module.
 *
 * With delayed reporting (MUST_DELAY_RACE_REPORTS) the buffered locations are registered
 * first and the buffered messages created afterwards, the same order as in the immediate
 * path. Once this has run, TsanOnReport() stops forwarding reports to the module.
 */
GTI_ANALYSIS_RETURN TSanMessages::fini()
{
    if (myDelayedReport) {
        for (const auto& pending : delayedLocations) {
            const auto& info = pending.locInfo;
            (*myNewLocFunc)(
                pending.pId,
                pending.lId,
                info.callName.c_str(),
                info.callName.length() + 1,
                NULL,
                info.codePtr,
                info.fileName.c_str(),
                info.fileName.length() + 1,
                NULL
#ifdef ENABLE_STACKTRACE
                ,
                pending.stack_levels,
                pending.stack_string.length() + 1, // length of all concatenated stack strings
                pending.stack_indices.size(),      // number of indices
                (int*)pending.stack_indices.data(),
                pending.stack_string.c_str() // read-only access
#endif
            );
        }

        for (const auto& pending : delayedMessages) {
            myLogger->createMessage(
                pending.msgId,
                pending.pId,
                pending.lId,
                pending.msgType,
                pending.text,
                pending.refLocations);
        }
    }

    finalized = true;
    return GTI_ANALYSIS_SUCCESS;
}
