#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0116,C0209,C0301,R0914,R0912,R0915,W0511,W0603,eval-used
######################################################################

import argparse
import bisect
import collections
import math
import re
import statistics
# from pprint import pprint

# Profile state: filled in while parsing the data file, read by the reports and VCD writer
Sections = []  # (tick, tuple of active section names) per SECTION_PUSH/SECTION_POP
LongestVcdStrValueLength = 0  # length of the longest SECTION_PUSH payload seen
Threads = collections.defaultdict(list)  # thread id -> list of mtask execution records
Mtasks = collections.defaultdict(lambda: dict(elapsed=0, end=0))  # mtask id -> summary
Cpus = collections.defaultdict(lambda: dict(mtask_time=0))  # cpu id -> ticks spent in mtasks
Global = dict(args={}, cpuinfo=collections.defaultdict(dict), stats={})
ElapsedTime = None  # total elapsed time
ExecGraphTime = 0  # total elapsed time executing an exec graph
ExecGraphIntervals = []  # list of (start, end) pairs

######################################################################


def read_data(filename):
    """Parse the profile data file `filename` into the module-level tables.

    Lines are classified in this order, first match wins:
      VLPROFEXEC ...         execution trace record (see below)
      VLPROFTHREAD <n>       selects the thread used for following MTASK_BEGINs
      VLPROFTHREAD/VERSION   any other such line is ignored
      VLPROF arg ...         stored into Global['args']
      VLPROF stat ...        stored into Global['stats']
      VLPROFPROC processor   selects the cpu for following VLPROFPROC lines
      VLPROFPROC <term>: <v> stored into Global['cpuinfo'][cpu]
      # ...                  ignored
    Anything else is printed only when Args.debug is set.

    Trace records update Sections, Threads, Mtasks, ExecGraphIntervals,
    LongestVcdStrValueLength and ExecGraphTime.
    """
    global LongestVcdStrValueLength
    global ExecGraphTime

    re_thread = re.compile(r'^VLPROFTHREAD (\d+)$')
    re_record = re.compile(r'^VLPROFEXEC (\S+) (\d+)(.*)$')
    re_payload_mtaskBegin = re.compile(r'id (\d+) predictStart (\d+) cpu (\d+)')
    re_payload_mtaskEnd = re.compile(r'id (\d+) predictCost (\d+)')

    re_arg1 = re.compile(r'VLPROF arg\s+(\S+)\+([0-9.]*)\s*')
    re_arg2 = re.compile(r'VLPROF arg\s+(\S+)\s+([0-9.]*)\s*$')
    re_stat = re.compile(r'VLPROF stat\s+(\S+)\s+([0-9.]+)')
    re_proc_cpu = re.compile(r'VLPROFPROC processor\s*:\s*(\d+)\s*$')
    re_proc_dat = re.compile(r'VLPROFPROC ([a-z_ ]+)\s*:\s*(.*)$')

    cpu = None  # processor number from the latest 'VLPROFPROC processor' line
    thread = None  # thread id from the latest VLPROFTHREAD line
    execGraphStart = None  # tick of the pending EXEC_GRAPH_BEGIN
    SectionStack = []  # names of the currently open sections
    mTaskThread = {}  # mtask id -> thread on which it last began

    with open(filename, "r", encoding="utf8") as fh:
        for line in fh:
            found = re_record.match(line)
            if found:
                kind, tick, payload = found.groups()
                tick = int(tick)
                payload = payload.strip()
                if kind == "SECTION_PUSH":
                    if len(payload) > LongestVcdStrValueLength:
                        LongestVcdStrValueLength = len(payload)
                    SectionStack.append(payload)
                    Sections.append((tick, tuple(SectionStack)))
                elif kind == "SECTION_POP":
                    assert SectionStack, "SECTION_POP without SECTION_PUSH"
                    del SectionStack[-1]
                    Sections.append((tick, tuple(SectionStack)))
                elif kind == "MTASK_BEGIN":
                    fields = re_payload_mtaskBegin.match(payload).groups()
                    mtask, predict_start, ecpu = (int(field) for field in fields)
                    mTaskThread[mtask] = thread
                    records = Threads[thread]
                    # The previous record on this thread must already be closed
                    assert not records or records[-1]['start'] <= records[-1]['end'] <= tick
                    records.append(
                        dict(start=tick, mtask=mtask, predict_start=predict_start, cpu=ecpu))
                    summary = Mtasks[mtask]
                    summary['begin'] = tick
                    summary['thread'] = thread
                    summary['predict_start'] = predict_start
                elif kind == "MTASK_END":
                    fields = re_payload_mtaskEnd.match(payload).groups()
                    mtask, predict_cost = (int(field) for field in fields)
                    summary = Mtasks[mtask]
                    begin = summary['begin']
                    # Close the most recent record of the thread that began this mtask
                    record = Threads[mTaskThread[mtask]][-1]
                    record['end'] = tick
                    record['predict_cost'] = predict_cost
                    summary['elapsed'] += tick - begin
                    summary['predict_cost'] = predict_cost
                    summary['end'] = max(summary['end'], tick)
                elif kind == "EXEC_GRAPH_BEGIN":
                    execGraphStart = tick
                elif kind == "EXEC_GRAPH_END":
                    ExecGraphTime += tick - execGraphStart
                    ExecGraphIntervals.append((execGraphStart, tick))
                    execGraphStart = None
                elif Args.debug:
                    print("-Unknown execution trace record: %s" % line)
                continue

            found = re_thread.match(line)
            if found:
                thread = int(found.group(1))
                continue

            if line.startswith(('VLPROFTHREAD', 'VLPROFVERSION')):
                continue

            # The '+'-separated form is tried before the space-separated form
            found = re_arg1.match(line) or re_arg2.match(line)
            if found:
                Global['args'][found.group(1)] = found.group(2)
                continue

            found = re_stat.match(line)
            if found:
                Global['stats'][found.group(1)] = found.group(2)
                continue

            found = re_proc_cpu.match(line)
            if found:
                cpu = int(found.group(1))
                continue

            # Processor details only count once a 'processor' line has been seen
            found = re_proc_dat.match(line) if cpu is not None else None
            if found:
                term = re.sub(r'\s+', '_', re.sub(r'\s+$', '', found.group(1)))
                value = re.sub(r'\s+$', '', found.group(2))
                Global['cpuinfo'][cpu][term] = value
                continue

            if line.startswith('#'):
                continue

            if Args.debug:
                print("-Unk: %s" % line)


def re_match_result(regexp, line, result_to):
    """Return the match of `regexp` at the start of `line`, or None.

    `result_to` is accepted only so existing call sites keep working; it is
    ignored.  Rebinding a parameter is never visible to the caller, so the
    match is only available through the return value.
    """
    del result_to  # intentionally unused
    return re.match(regexp, line)


######################################################################


def report():
    """Print the complete text report to stdout.

    Prints the recorded argument settings and a summary, then runs the
    mtask, CPU and section sub-reports, and finally any placement warnings.
    As side effects this accumulates per-CPU mtask time into Cpus and sets
    the global ElapsedTime from the 'ticks' statistic.
    """
    global ElapsedTime

    print("Verilator Gantt report")

    print("\nArgument settings:")
    for name, setting in sorted(Global['args'].items()):
        separator = "+" if name.startswith("+") else " "
        print("  %s%s%s" % (name, separator, setting))

    # Charge every executed mtask's duration to the CPU it ran on
    for thread_records in Threads.values():
        for rec in thread_records:
            on_cpu = rec['cpu']
            busy = rec['end'] - rec['start']
            Cpus[on_cpu]['mtask_time'] += busy

    stats = Global['stats']
    ElapsedTime = int(stats['ticks'])
    nthreads = int(stats['threads'])
    ncpus = max(len(Cpus), 1)

    print("\nSummary:")
    print("  Total elapsed time = {} rdtsc ticks".format(ElapsedTime))
    print("  Parallelized code  = {:.2%} of elapsed time".format(ExecGraphTime / ElapsedTime))
    print("  Total threads      = %d" % nthreads)
    print("  Total CPUs used    = %d" % ncpus)
    print("  Total mtasks       = %d" % len(Mtasks))
    print("  Total yields       = %d" % int(stats.get('yields', 0)))

    report_mtasks()
    report_cpus()
    report_sections()

    if nthreads > ncpus:
        print()
        print("%%Warning: There were fewer CPUs (%d) than threads (%d)." % (ncpus, nthreads))
        print("        : See docs on use of numactl.")
    else:
        # Flags are left in Global by report_cpus()
        placement_warnings = (
            ('cpu_socket_cores_warning',
             "%Warning: Multiple threads scheduled on same hyperthreaded core."),
            ('cpu_sockets_warning', "%Warning: Threads scheduled on multiple sockets."),
        )
        for flag, message in placement_warnings:
            if flag in Global:
                print()
                print(message)
                print("        : See docs on use of numactl.")
    print()


def report_mtasks():
    """Print utilization, speedup and prediction-accuracy statistics for mtasks.

    Does nothing when no mtasks were recorded.  Side effects:
      - sets Global['predict_last_end'] to the latest predicted mtask end
      - for mtasks with a positive elapsed time, a predict_cost of 0 is
        replaced by 1 in Mtasks (so its logarithm can be taken)

    The log(p2e) statistics are skipped when no mtask has a positive elapsed
    time, since there is then nothing to compute them from.
    """
    if not Mtasks:
        return

    nthreads = int(Global['stats']['threads'])

    # If we know cycle time in the same (rdtsc) units,
    # this will give us an actual utilization number,
    # (how effectively we keep the cores busy.)
    #
    # It also gives us a number we can compare against
    # serial mode, to estimate the overhead of data sharing,
    # which will show up in the total elapsed time. (Overhead
    # of synchronization and scheduling should not.)
    total_mtask_time = 0
    long_mtask_time = 0
    long_mtask = None
    predict_mtask_time = 0
    predict_elapsed = 0
    for mtaskId, record in Mtasks.items():
        predict_mtask_time += record['predict_cost']
        total_mtask_time += record['elapsed']
        predict_end = record['predict_start'] + record['predict_cost']
        predict_elapsed = max(predict_elapsed, predict_end)
        if record['elapsed'] > long_mtask_time:
            long_mtask_time = record['elapsed']
            long_mtask = mtaskId
    Global['predict_last_end'] = predict_elapsed

    serialTime = ElapsedTime - ExecGraphTime

    def subReport(elapsed, work):
        print("  Thread utilization = {:7.2%}".format(work / (elapsed * nthreads)))
        print("  Speedup            = {:6.3}x".format(work / elapsed))

    print("\nParallelized code, measured:")
    subReport(ExecGraphTime, total_mtask_time)

    print("\nParallelized code, predicted during static scheduling:")
    subReport(predict_elapsed, predict_mtask_time)

    print("\nAll code, measured:")
    subReport(ElapsedTime, serialTime + total_mtask_time)

    print("\nAll code, measured, scaled by predicted speedup:")
    expectedParallelSpeedup = predict_mtask_time / predict_elapsed
    scaledElapsed = serialTime + total_mtask_time / expectedParallelSpeedup
    subReport(scaledElapsed, serialTime + total_mtask_time)

    # (mtask id, log(predicted cost / elapsed)) for every mtask that took time,
    # in ascending mtask id order
    p2e_by_mtask = []
    for mtask in sorted(Mtasks.keys()):
        record = Mtasks[mtask]
        if record['elapsed'] > 0:
            if record['predict_cost'] == 0:
                record['predict_cost'] = 1  # don't log(0) below
            p2e_by_mtask.append((mtask, math.log(record['predict_cost'] / record['elapsed'])))

    print("\nMTask statistics:")
    print("  Longest mtask id = {}".format(long_mtask))
    print("  Longest mtask time = {:.2%} of time elapsed in parallelized code".format(
        long_mtask_time / ExecGraphTime))

    if not p2e_by_mtask:
        # No mtask has a positive elapsed time: there is no min/max mtask to
        # name and mean/pstdev are undefined for empty data.
        return

    # min()/max() return the first extreme, i.e. the lowest mtask id on ties
    min_mtask, min_p2e = min(p2e_by_mtask, key=lambda pair: pair[1])
    max_mtask, max_p2e = max(p2e_by_mtask, key=lambda pair: pair[1])

    print("  min log(p2e) = %0.3f" % min_p2e, end="")
    print("  from mtask %d (predict %d," % (min_mtask, Mtasks[min_mtask]['predict_cost']), end="")
    print(" elapsed %d)" % Mtasks[min_mtask]['elapsed'])
    print("  max log(p2e) = %0.3f" % max_p2e, end="")
    print("  from mtask %d (predict %d," % (max_mtask, Mtasks[max_mtask]['predict_cost']), end="")
    print(" elapsed %d)" % Mtasks[max_mtask]['elapsed'])

    p2e_ratios = [ratio for _, ratio in p2e_by_mtask]
    stddev = statistics.pstdev(p2e_ratios)
    mean = statistics.mean(p2e_ratios)
    print("  mean = %0.3f" % mean)
    print("  stddev = %0.3f" % stddev)
    print("  e ^ stddev = %0.3f" % math.exp(stddev))


def report_cpus():
    """Print the per-CPU table and record CPU placement warnings in Global.

    For every CPU that executed an mtask, prints the share of elapsed ticks
    spent in mtasks plus the socket, core and model taken from
    Global['cpuinfo'] (blank when unknown).

    Side effects on Global:
      'cpu_sockets'               socket id -> number of used CPUs on it
      'cpu_socket_cores'          "socket__core" -> number of used CPUs on it
      'cpu_sockets_warning'       set when more than one socket was used
      'cpu_socket_cores_warning'  set when two used CPUs share a socket+core
    """
    print("\nCPU info:")

    Global['cpu_sockets'] = collections.defaultdict(lambda: 0)
    Global['cpu_socket_cores'] = collections.defaultdict(lambda: 0)

    print("   Id | Time spent executing MTask | Socket | Core | Model")
    print("      | % of elapsed ticks / ticks |        |      |")
    print("  ====|============================|========|======|======")
    for cpu in sorted(Cpus):
        socket = ""
        core = ""
        model = ""
        if cpu in Global['cpuinfo']:
            cpuinfo = Global['cpuinfo'][cpu]
            if 'physical_id' in cpuinfo and 'core_id' in cpuinfo:
                socket = cpuinfo['physical_id']
                Global['cpu_sockets'][socket] += 1
                core = cpuinfo['core_id']
                Global['cpu_socket_cores'][socket + "__" + core] += 1

            if 'model_name' in cpuinfo:
                model = cpuinfo['model_name']

        print("  {:3d} | {:7.2%} / {:16d} | {:>6s} | {:>4s} | {}".format(
            cpu, Cpus[cpu]['mtask_time'] / ElapsedTime, Cpus[cpu]['mtask_time'], socket, core,
            model))

    if len(Global['cpu_sockets']) > 1:
        Global['cpu_sockets_warning'] = True
    # Two used CPUs on the same socket+core share that core regardless of how
    # many sockets are in use, so this check must not depend on the one above.
    if any(count > 1 for count in Global['cpu_socket_cores'].values()):
        Global['cpu_socket_cores_warning'] = True


def report_sections():
    """Print the hierarchical section profile; does nothing without Sections.

    The time between two consecutive section records is charged as self time
    to the section stack that was active in between (the root when no section
    was open).  Any part of ElapsedTime not covered by section records is
    added to the root so that the tree sums to ElapsedTime.  Children are
    listed in order of decreasing total time.
    """
    if not Sections:
        return
    print("\nSection profile:")

    sectionTree = [0, {}, 1]  # [selfTime, childTrees, numberOfTimesEntered]

    def findNode(stack):
        # Walk from the root along 'stack', creating missing nodes on the way
        node = sectionTree
        for item in stack:
            node = node[1].setdefault(item, [0, {}, 0])
        return node

    prevTick = 0
    prevStack = ()
    for tick, stack in Sections:
        if len(stack) > len(prevStack):
            # A deeper stack means a section was just entered
            findNode(stack)[2] += 1
        findNode(prevStack)[0] += tick - prevTick
        prevTick = tick
        prevStack = stack

    def treeSum(tree):
        return tree[0] + sum(treeSum(subTree) for subTree in tree[1].values())

    # Make sure the tree sums to the elapsed time
    sectionTree[0] += ElapsedTime - treeSum(sectionTree)

    def printTree(prefix, name, entries, tree):
        # 'entries' is the parent's entry count, used for the relative column
        print("  {:7.2%} | {:7.2%} | {:8} | {:10.2f} | {}".format(
            treeSum(tree) / ElapsedTime, tree[0] / ElapsedTime, tree[2], tree[2] / entries,
            prefix + name))
        for child in sorted(tree[1], key=lambda key: -treeSum(tree[1][key])):
            printTree(prefix + "  ", child, tree[2], tree[1][child])

    print(" Total    | Self    | Total    | Relative   | Section")
    print(" time     | time    | entries  | entries    |  name  ")
    print("==========|=========|==========|============|========")
    printTree("", "*TOTAL*", 1, sectionTree)


######################################################################


def write_vcd(filename):
    """Write the collected profile to `filename` as a VCD waveform.

    Signals are placed under a top scope 'gantt':
      measured   per thread: running mtask; per cpu: running thread and mtask
      predicted  per thread: mtask placement from predict_start/predict_cost,
                 scaled into each measured exec graph interval
      stats      number of concurrently running mtasks (measured, predicted)
      section    innermost section name and section depth (only if Sections)

    The predicted graph reads Global['predict_last_end'], which is set by
    report_mtasks(); it is only needed when ExecGraphIntervals is non-empty.
    """
    print("Writing %s" % filename)
    with open(filename, "w", encoding="utf8") as fh:
        # dict of dicts of hierarchy elements/signal name  -> (code, width)
        scopeSigs = {}
        # map from time -> code -> value
        values = collections.defaultdict(lambda: {})

        # Per graph: time -> change in the number of running mtasks at that time
        parallelism = {
            'measured': collections.defaultdict(lambda: 0),
            'predicted': collections.defaultdict(lambda: 0)
        }
        # Ensure both parallelism signals get a value at time 0
        parallelism['measured'][0] = 0
        parallelism['predicted'][0] = 0

        nextFreeCode = 0

        def getCode(width, *names):
            # Return the code of the signal at hierarchy path 'names',
            # allocating the next free code the first time it is requested.
            # Codes are therefore numbered in order of first use.
            nonlocal nextFreeCode
            scope = scopeSigs
            *path, name = names
            for item in path:
                scope = scope.setdefault(item, {})
            code, oldWidth = scope.setdefault(name, (nextFreeCode, width))
            # A signal must always be requested with the same width
            assert oldWidth == width
            if code == nextFreeCode:
                # setdefault() inserted a new entry, so the code is now taken
                nextFreeCode += 1
            return code

        def addValue(code, time, val):
            # Record 'val' for signal 'code' at 'time'; a later value for the
            # same (time, code) replaces an earlier one. None means 'z'.
            if isinstance(val, str):
                # Encode each character as (at least) 8 binary digits, then
                # pad on the right to the width of the longest section name
                buf = ""
                for c in val:
                    buf += bin(ord(c))[2:].rjust(8, "0")
                val = buf.ljust(LongestVcdStrValueLength * 8, "0")
            values[time][code] = val

        # Measured graph
        for thread in sorted(Threads):
            mcode = getCode(32, 'measured', 't%d_mtask' % thread)
            for record in Threads[thread]:
                start = record['start']
                mtask = record['mtask']
                end = record['end']
                cpu = record['cpu']
                addValue(mcode, start, mtask)
                addValue(mcode, end, None)
                parallelism['measured'][start] += 1
                parallelism['measured'][end] -= 1

                code = getCode(32, 'measured', 'cpu%d_thread' % cpu)
                addValue(code, start, thread)
                addValue(code, end, None)

                code = getCode(32, 'measured', 'cpu%d_mtask' % cpu)
                addValue(code, start, mtask)
                addValue(code, end, None)

        # Sorted start and end ticks of all measured mtask records, for bisect
        tStart = sorted(_['start'] for records in Threads.values() for _ in records)
        tEnd = sorted(_['end'] for records in Threads.values() for _ in records)

        # Predicted graph
        for start, end in ExecGraphIntervals:
            # Find the earliest MTask start after the start point, and the
            # latest MTask end before the end point, so we can scale to the
            # same range
            start = tStart[bisect.bisect_left(tStart, start)]
            end = tEnd[bisect.bisect_right(tEnd, end) - 1]
            # Compute scale so predicted graph is of same width as interval
            measured_scaling = (end - start) / Global['predict_last_end']
            # Predict mtasks that fill the time the execution occupied
            for mtask in Mtasks:
                thread = Mtasks[mtask]['thread']
                pred_scaled_start = start + int(Mtasks[mtask]['predict_start'] * measured_scaling)
                pred_scaled_end = start + int(
                    (Mtasks[mtask]['predict_start'] + Mtasks[mtask]['predict_cost']) *
                    measured_scaling)
                # Skip mtasks whose scaled prediction has zero width
                if pred_scaled_start == pred_scaled_end:
                    continue

                mcode = getCode(32, 'predicted', 't%d_mtask' % thread)
                addValue(mcode, pred_scaled_start, mtask)
                addValue(mcode, pred_scaled_end, None)

                parallelism['predicted'][pred_scaled_start] += 1
                parallelism['predicted'][pred_scaled_end] -= 1

        # Parallelism graph
        for measpred in ('measured', 'predicted'):
            pcode = getCode(32, 'stats', '%s_parallelism' % measpred)
            # Running sum of the +1/-1 changes gives the count at each time
            value = 0
            for time in sorted(parallelism[measpred].keys()):
                value += parallelism[measpred][time]
                addValue(pcode, time, value)

        # Section graph
        if Sections:
            scode = getCode(LongestVcdStrValueLength * 8, "section", "trace")
            dcode = getCode(32, "section", "depth")
            for time, stack in Sections:
                # Innermost section name, or 'z' when no section is open
                addValue(scode, time, stack[-1] if stack else None)
                addValue(dcode, time, len(stack))

        # Create output file
        fh.write("$version Generated by verilator_gantt $end\n")
        fh.write("$timescale 1ns $end\n")
        fh.write("\n")

        # Every code declared in the header, collected by writeScope()
        all_codes = set()

        def writeScope(scope):
            # Recursively write $scope/$var declarations in sorted name order
            assert isinstance(scope, dict)
            for key in sorted(scope):
                val = scope[key]
                if isinstance(val, dict):
                    fh.write("  $scope module %s $end\n" % key)
                    writeScope(val)
                    fh.write("  $upscope $end\n")
                else:
                    (code, width) = val
                    all_codes.add(code)
                    fh.write("   $var wire %d v%x %s [%d:0] $end\n" %
                             (width, code, key, width - 1))

        fh.write(" $scope module gantt $end\n")
        writeScope(scopeSigs)
        fh.write(" $upscope $end\n")
        fh.write("$enddefinitions $end\n")
        fh.write("\n")

        # Value changes, in increasing time order
        first = True
        for time in sorted(values):
            if first:
                first = False
                # Start with Z for any signals without time zero data
                for code in sorted(all_codes):
                    if code not in values[time]:
                        addValue(code, time, None)
            fh.write("#%d\n" % time)
            for code in sorted(values[time]):
                value = values[time][code]
                if value is None:
                    fh.write("bz v%x\n" % code)
                elif isinstance(value, str):
                    # Strings were already converted to binary by addValue()
                    fh.write("b%s v%x\n" % (value, code))
                else:
                    fh.write("b%s v%x\n" % (format(value, 'b'), code))


######################################################################

# Command line definition; the parsed result is kept in the module-level Args,
# which read_data() consults for --debug.
parser = argparse.ArgumentParser(
    allow_abbrev=False,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="""Create Gantt chart of multi-threaded execution.

Verilator_gantt creates a visual representation to help analyze Verilator
multithreaded simulation performance, by showing when each macro-task
starts and ends, and showing when each thread is busy or idle.

For documentation see
https://verilator.org/guide/latest/exe_verilator_gantt.html""",
    epilog="""Copyright 2018-2024 by Wilson Snyder. This program is free software; you
can redistribute it and/or modify it under the terms of either the GNU
Lesser General Public License Version 3 or the Perl Artistic License
Version 2.0.

SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0""")

parser.add_argument('--debug', action='store_true', help='enable debug')
parser.add_argument('--no-vcd', help='disable creating vcd', action='store_true')
parser.add_argument('--vcd', help='filename for vcd output', default='profile_exec.vcd')
# nargs='?' makes the positional optional; without it argparse requires the
# argument and the default below could never take effect.
parser.add_argument('filename',
                    nargs='?',
                    help='input profile_exec.dat filename to process',
                    default='profile_exec.dat')

Args = parser.parse_args()

# Parse the data, print the text report, then (unless disabled) write the VCD.
# report() must run before write_vcd(): it computes values the VCD writer reads.
read_data(Args.filename)
report()
if not Args.no_vcd:
    write_vcd(Args.vcd)

######################################################################
# Local Variables:
# compile-command: "./verilator_gantt ../test_regress/t/t_gantt_io.dat"
# End:
