/* Part of the MUST Project, under BSD-3-Clause License
 * See https://hpc.rwth-aachen.de/must/LICENSE for license information.
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "debug.h"

#include <linux/limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <sys/syscall.h>
#include <time.h>
#include <unistd.h>

#include <atomic>
#include <chrono>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <ostream>
#include <sstream>
#include <string>

#include "mpi.h"

/**
 * Holder for the stream that debug output is written to.
 *
 * The object either forwards to an externally owned stream (std::cout by
 * default) or owns a file stream opened from a file name. In both cases the
 * target is reached through get().
 *
 * The class stores a reference that may point at its own `byname` member, so
 * copying or moving an instance would leave the new object referring to the
 * old object's file stream. Copy and move are therefore disabled.
 */
class MustLogStream
{
    std::ofstream byname; // file stream; only opened by the file name constructor
    std::ostream& os;     // stream returned by get(): external stream or `byname`

  public:
    /** Log to an existing stream that outlives this object (default: std::cout). */
    explicit MustLogStream(std::ostream& stream = std::cout) : byname(), os(stream) {}

    /**
     * Log to the file `filename`.
     *
     * If the file cannot be opened a message is printed to std::cerr; the
     * object is still constructed, but get() then returns a stream in a failed
     * state.
     */
    explicit MustLogStream(const std::string& filename) : byname(filename), os(this->byname)
    {
        if (!os)
            std::cerr << "[MUST Logging] Could not open debug logging file: " << filename << "."
                      << std::endl;
    }

    // `os` may alias `byname`; a copied/moved object would reference the source's member.
    MustLogStream(const MustLogStream&) = delete;
    MustLogStream& operator=(const MustLogStream&) = delete;
    MustLogStream(MustLogStream&&) = delete;
    MustLogStream& operator=(MustLogStream&&) = delete;

    /** @return the stream log output should be written to. */
    std::ostream& get() { return os; }
};

// Active log level. -1 means "not initialized yet"; mustDebugDisplayMsg() runs
// mustDebugInit() while this is -1, and mustDebugInit() stores the parsed level last.
static std::atomic<int> mustDebugLevel{-1};
// Process id, cached by mustDebugInit() via getpid(); printed in every log prefix.
static int pid = -1;
// MPI rank in MPI_COMM_WORLD, or -1 until initMpiRank() was able to query it.
static std::atomic<int> rank{-1};
// Host name truncated at the first '.', filled by mustDebugInit().
// NOTE(review): not read anywhere in this file's visible code.
static char hostname[1024];
// Bit mask of modules whose messages are shown; overridden by MUST_DEBUG_MODULE.
uint64_t mustDebugMask = MUST_ALL;
// Emit terminal color escape sequences; switched off by MUST_DEBUG_ENABLE_COLORS=0.
static bool mustDebugEnableColors = true;
// Call abort() after a fatal message; switched off by MUST_DEBUG_ABORT_ON_FATAL=0.
static bool mustDebugEnableAbortOnFatal = true;
// Result flag of PMPI_Initialized(), cached once it reports non-zero.
static int mustDebugMpiInitialized = 0;
// Serializes mustDebugInit() and initMpiRank().
static std::mutex mustDebugMutex;

// Per-thread cache of the kernel thread id (SYS_gettid); -1 until first log call.
static __thread int tid = -1;
// Output stream wrapper; allocated in mustDebugInit(), released in mustDebugDestructor().
MustLogStream* mustDebugStream;

// Terminal color escape sequences ("ESC [ ... m"). According to the original
// note these are only meant for Linux (e.g. Ubuntu) and macOS terminals.
#define RESET "\033[0m"
#define BLACK "\033[30m"              /* Black */
#define RED "\033[31m"                /* Red */
#define GREEN "\033[32m"              /* Green */
#define YELLOW "\033[33m"             /* Yellow */
#define BLUE "\033[34m"               /* Blue */
#define MAGENTA "\033[35m"            /* Magenta */
#define CYAN "\033[36m"               /* Cyan */
#define WHITE "\033[37m"              /* White */
#define BOLDBLACK "\033[1m\033[30m"   /* Bold Black */
#define BOLDRED "\033[1m\033[31m"     /* Bold Red */
#define BOLDGREEN "\033[1m\033[32m"   /* Bold Green */
#define BOLDYELLOW "\033[1m\033[33m"  /* Bold Yellow */
#define BOLDBLUE "\033[1m\033[34m"    /* Bold Blue */
#define BOLDMAGENTA "\033[1m\033[35m" /* Bold Magenta */
#define BOLDCYAN "\033[1m\033[36m"    /* Bold Cyan */
#define BOLDWHITE "\033[1m\033[37m"   /* Bold White */

// Color used for the "MUST:<LEVEL>" tag. mustDebugLog() indexes this table with
// the numeric log level, so the entries must stay in the same order as
// mustDebugLevel2String below.
static const char* mustDebugColor2String[] =
    {WHITE, BOLDRED, BOLDRED, BOLDRED, BOLDYELLOW, BOLDWHITE, BOLDWHITE};

// Printable name of each log level, indexed by the numeric log level
// (index 0 is "NONE", index 6 is "TRACE").
static const char* mustDebugLevel2String[] =
    {"NONE", "FATAL", "ERROR", "DEV_ERROR", "WARN", "INFO", "TRACE"};

/* TODO: See whether there is a better way than variable on heap */
/**
 * Release the log stream (which also flushes and closes a log file, if any).
 *
 * Registered as a destructor function and additionally called explicitly by
 * mustDebugAbortOnFatal(). Because it can therefore run more than once, the
 * pointer is cleared after deletion so that a repeated call is a no-op instead
 * of a double delete.
 */
__attribute__((destructor)) void mustDebugDestructor(void)
{
    delete mustDebugStream;
    mustDebugStream = nullptr;
}

/**
 * Store the host name, cut at the first occurrence of `delim`, in `hostname`.
 *
 * @param hostname  Destination buffer of at least `maxlen` bytes.
 * @param maxlen    Size of the destination buffer in bytes.
 * @param delim     Character at which the name is truncated (e.g. '.' to drop
 *                  the domain part).
 * @return true if gethostname() succeeded. On failure the buffer holds
 *         "unknown" (truncated to fit) and false is returned. If `hostname` is
 *         null or `maxlen` is not positive, the buffer is not touched and
 *         false is returned.
 *
 * Whenever the buffer is written, the result is NUL-terminated within `maxlen`.
 */
bool getHostName(char* hostname, int maxlen, const char delim)
{
    if (hostname == NULL || maxlen <= 0)
        return false;

    if (gethostname(hostname, maxlen) != 0) {
        // snprintf always terminates, unlike strncpy when maxlen < 8.
        snprintf(hostname, static_cast<size_t>(maxlen), "%s", "unknown");
        return false;
    }

    // gethostname() may leave a truncated name without a terminator.
    hostname[maxlen - 1] = '\0';

    char* cut = strchr(hostname, delim);
    if (cut != NULL)
        *cut = '\0';
    return true;
}

/**
 * Look up this process' rank in MPI_COMM_WORLD and cache it in `rank`.
 *
 * Runs under mustDebugMutex. Does nothing when the rank is already known or
 * when PMPI_Initialized() does not (yet) report an initialized MPI; in the
 * latter case `rank` stays -1 and a later call tries again.
 */
inline void initMpiRank()
{
    std::lock_guard<std::mutex> guard{mustDebugMutex};

    // Another thread may have filled in the rank while we waited for the lock.
    if (rank != -1)
        return;

    // Only ask MPI again as long as it has not reported "initialized" before.
    if (!mustDebugMpiInitialized)
        PMPI_Initialized(&mustDebugMpiInitialized);
    if (!mustDebugMpiInitialized)
        return;

    int worldRank;
    PMPI_Comm_rank(MPI_COMM_WORLD, &worldRank);
    rank = worldRank;
}

static void mustDebugInit()
{
    std::lock_guard<std::mutex> lock{mustDebugMutex};
    if (mustDebugLevel != -1)
        return;
    const char* must_debug = getenv("MUST_DEBUG");
    int tempmustDebugLevel = -1;
    if (must_debug == NULL) {
        tempmustDebugLevel = MUST_LOG_ERROR;
    } else if (strcasecmp(must_debug, "FATAL") == 0) {
        tempmustDebugLevel = MUST_LOG_FATAL;
    } else if (strcasecmp(must_debug, "ERROR") == 0) {
        tempmustDebugLevel = MUST_LOG_ERROR;
    } else if (strcasecmp(must_debug, "DEV_ERROR") == 0) {
        tempmustDebugLevel = MUST_LOG_DEV_ERROR;
    } else if (strcasecmp(must_debug, "WARN") == 0) {
        tempmustDebugLevel = MUST_LOG_WARN;
    } else if (strcasecmp(must_debug, "INFO") == 0) {
        tempmustDebugLevel = MUST_LOG_INFO;
    } else if (strcasecmp(must_debug, "TRACE") == 0) {
        tempmustDebugLevel = MUST_LOG_TRACE;
    }

    const char* mustDebugEnableColorsEnv = getenv("MUST_DEBUG_ENABLE_COLORS");
    if (mustDebugEnableColorsEnv && atoi(mustDebugEnableColorsEnv) == 0) {
        mustDebugEnableColors = false;
    }

    const char* mustDebugAbortOnFatalEnv = getenv("MUST_DEBUG_ABORT_ON_FATAL");
    if (mustDebugAbortOnFatalEnv && atoi(mustDebugAbortOnFatalEnv) == 0) {
        mustDebugEnableAbortOnFatal = false;
    }

    pid = getpid();
    if (getenv("MUST_DEBUG_LOG_PATH")) {
        std::stringstream ss;
        ss << getenv("MUST_DEBUG_LOG_PATH") << "." << pid;
        std::cout << "[MUST Debug] Process " << pid << " writing to " << ss.str() << std::endl;
        mustDebugStream = new MustLogStream(ss.str());
    } else {
        mustDebugStream = new MustLogStream();
    }

    /* Parse the MUST_DEBUG_SUBSYS env var
     * This can be a comma separated list such as INIT,COLL
     * or ^INIT,COLL etc
     */
    const char* mustDebugSubsysEnv = getenv("MUST_DEBUG_MODULE");
    if (mustDebugSubsysEnv != NULL) {
        int invert = 0;
        if (mustDebugSubsysEnv[0] == '^') {
            invert = 1;
            mustDebugSubsysEnv++;
        }
        mustDebugMask = invert ? ~0ULL : 0ULL;
        char* mustDebugSubsys = strdup(mustDebugSubsysEnv);
        char* subsys = strtok(mustDebugSubsys, ",");
        while (subsys != NULL) {
            uint64_t mask = 0;
            if (strcasecmp(subsys, "INIT") == 0) {
                mask = MUST_INIT;
            } else if (strcasecmp(subsys, "DLP2P") == 0) {
                mask = MUST_DLP2P;
            } else if (strcasecmp(subsys, "DLCOLL") == 0) {
                mask = MUST_DLCOLL;
            } else if (strcasecmp(subsys, "CHECK") == 0) {
                mask = MUST_CHECK;
            } else if (strcasecmp(subsys, "TRACK") == 0) {
                mask = MUST_TRACK;
            } else if (strcasecmp(subsys, "TSAN") == 0) {
                mask = MUST_TSAN;
            } else if (strcasecmp(subsys, "FIN") == 0) {
                mask = MUST_TSAN;
            } else if (strcasecmp(subsys, "MSG") == 0) {
                mask = MUST_MSG;
            } else if (strcasecmp(subsys, "UNMAPPED") == 0) {
                mask = MUST_UNMAPPED;
            } else if (strcasecmp(subsys, "TYPEART") == 0) {
                mask = MUST_UNMAPPED;
            } else if (strcasecmp(subsys, "DLOTHER") == 0) {
                mask = MUST_DLOTHER;
            } else if (strcasecmp(subsys, "OMP") == 0) {
                mask = MUST_OMP;
            } else if (strcasecmp(subsys, "ALL") == 0) {
                mask = MUST_ALL;
            }
            if (mask) {
                if (invert)
                    mustDebugMask &= ~mask;
                else
                    mustDebugMask |= mask;
            }
            subsys = strtok(NULL, ",");
        }
        free(mustDebugSubsys);
    }

    // Cache pid and hostname
    getHostName(hostname, 1024, '.');
    mustDebugLevel = tempmustDebugLevel;
}

// class NullStream : public std::ostream {
// public:
//     NullStream() : std::ostream(nullptr) {}
//     NullStream(const NullStream &) : std::ostream(nullptr) {}
// };
//
// static NullStream mustDebugNullStream;

/**
 * Format the current local wall-clock time for the log prefix.
 *
 * @return the time formatted with "%Y-%m-%d.%X" (date, a dot, then the
 *         locale's time representation), or "unknown-time" if the time could
 *         not be converted to local time.
 *
 * Uses localtime_r() with a local buffer: std::localtime() returns a pointer
 * to shared static storage and must not be called concurrently from several
 * logging threads.
 */
inline std::string getCurrentTime()
{
    const auto now = std::chrono::system_clock::now();
    const std::time_t in_time_t = std::chrono::system_clock::to_time_t(now);

    std::tm local{};
    if (localtime_r(&in_time_t, &local) == nullptr)
        return "unknown-time";

    std::stringstream ss;
    ss << std::put_time(&local, "%Y-%m-%d.%X");
    return ss.str();
}

/**
 * Decide whether a message of the given level and module should be shown.
 *
 * Triggers mustDebugInit() while the log level still has its initial value -1.
 *
 * @param level  Level of the message.
 * @param mask   Module bit(s) of the message.
 * @return true if the configured level is at least `level` and `mask` shares
 *         at least one bit with mustDebugMask.
 */
bool mustDebugDisplayMsg(mustDebugLogLevel level, unsigned long mask)
{
    if (mustDebugLevel == -1)
        mustDebugInit();

    if (mustDebugLevel < level)
        return false;
    return (mask & mustDebugMask) != 0;
}

/**
 * Abort the process after a fatal log message.
 *
 * Acts only if `level` is MUST_LOG_FATAL, aborting on fatal messages is
 * enabled and the configured log level is at least `level`. In that case a
 * final notice is logged, the log stream is released through
 * mustDebugDestructor() and abort() is called.
 *
 * @param level     Level of the message that was just logged.
 * @param mask      Module bit(s), forwarded to mustDebugLog().
 * @param function  Name of the reporting function, forwarded to mustDebugLog().
 * @param file      Source file of the report, forwarded to mustDebugLog().
 * @param line      Source line of the report, forwarded to mustDebugLog().
 */
void mustDebugAbortOnFatal(
    mustDebugLogLevel level,
    unsigned long mask,
    const char* function,
    const char* file,
    int line)
{
    if (level != MUST_LOG_FATAL || !mustDebugEnableAbortOnFatal)
        return;
    if (mustDebugLevel < level)
        return;

    std::ostream& log = mustDebugLog(level, mask, function, file, line);
    log << "MUST will now abort!\n";

    // abort() does not run destructor functions, so release the stream by hand.
    mustDebugDestructor();
    abort();
}

/**
 * Write the log prefix for one message and return the stream for its text.
 *
 * The prefix has the form
 *   "[<time> <rank>:<pid>:<tid> MUST:<LEVEL> <function>@<file>:<line>] "
 * with the level tag colored when colors are enabled. Fatal messages
 * additionally get a request to report the output.
 *
 * No level/module filtering and no initialization happens here: callers are
 * expected to have called mustDebugDisplayMsg() beforehand.
 *
 * @param level     Level of the message; used as index into the level tables.
 * @param mask      Module bit(s) of the message (not used by this function).
 * @param function  Name of the reporting function.
 * @param file      Source file of the report.
 * @param line      Source line of the report.
 * @return the log stream, positioned directly after the prefix.
 */
std::ostream& mustDebugLog(
    mustDebugLogLevel level,
    unsigned long mask,
    const char* function,
    const char* file,
    int line)
{
    // Lazily resolve the MPI rank and this thread's kernel thread id.
    if (rank == -1)
        initMpiRank();
    if (tid == -1)
        tid = syscall(SYS_gettid);

    const bool colored = mustDebugEnableColors;
    const char* const levelColor = colored ? mustDebugColor2String[level] : "";
    const char* const colorOff = colored ? RESET : "";

    // Assemble the prefix first and hand it to the stream in a single write
    // to avoid interleaving with output of other writers.
    std::stringstream prefix;
    prefix << "[" << getCurrentTime() << " " << rank << ":" << pid << ":" << tid << levelColor
           << " MUST:" << mustDebugLevel2String[level] << colorOff << " " << function << "@"
           << file << ":" << line << "] ";

    if (level == MUST_LOG_FATAL) {
        prefix << (colored ? BOLDBLUE : "")
               << "!Please report this output to must-feedback@lists.rwth-aachen.de! "
               << colorOff;
    }

    std::ostream& sink = mustDebugStream->get();
    sink << prefix.str();
    return sink;
}
