# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import ast
from collections.abc import Iterator
from typing import Any, Optional, cast

import numpy as np
import pandas as pd

from nsys_recipe.lib import arm_metrics as am
from nsys_recipe.lib import exceptions, overlap
from nsys_recipe.lib.typing_helpers import _PandasNamedTuple


class Architecture:
    """Plain-string identifiers of CPU architectures (e.g. for `cpu_arch`)."""

    UNKNOWN: str = "unknown"
    AARCH64_SBSA: str = "aarch64_sbsa"
    AARCH64_TEGRA: str = "aarch64_tegra"
    X86_64: str = "x86_64"


# Values of the `threadState` column of the scheduling events that are needed
# to recognize the two scheduling-out events emitted when a thread terminates
# (TERMINATED first, then UNKNOWN).
THREAD_STATE_UNKNOWN = 0
THREAD_STATE_TERMINATED = 5


class SchedEventError(exceptions.ValueError):
    """Raised when thread scheduling events are missing or inconsistent."""


def get_cpu_activity_ranges(thread_sched_df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute an array of CPU activity ranges, where such a range means that
    in a particular CPU, a particular thread from a particular process
    was scheduled for that time.

    The scheduling events are sorted by time and validated twice: grouped per
    CPU (a scheduling-in event followed by a scheduling-out event must belong
    to the same global TID) and grouped per global TID (a scheduling-in event
    followed by a scheduling-out event must happen on the same CPU). Every
    matching (scheduling-in, scheduling-out) pair of the per-TID pass becomes
    one CPU activity range.

    Parameters
    ----------
    thread_sched_df : pd.DataFrame
        DataFrame containing thread scheduling events from the SQL table
        SCHED_EVENTS. The columns read here are start, cpu, globalTid, tid,
        pid, isSchedIn and threadState.

    Returns
    -------
    pd.DataFrame
        DataFrame containing CPU activity ranges with columns:
        cpu, start, end, tid, pid, sorted by cpu and start.

    Raises
    ------
    SchedEventError
        If there are fewer than two events, or if the events are inconsistent:
        two events of the same CPU/thread at the same time, two consecutive
        scheduling-in (or scheduling-out) events (except the TERMINATED/UNKNOWN
        pair of scheduling-out events emitted on thread termination), or a
        scheduling-in/out pair whose global TID or CPU does not match.
    """
    if len(thread_sched_df) < 2:
        raise SchedEventError(
            "The thread scheduling events table does not contain enough "
            + "events to form CPU activity."
        )

    def get_missing_event_msg(row: pd.Series, sched_in_missing: bool) -> str:
        return (
            f"The scheduling {'in' if sched_in_missing else 'out'} "
            + f"event is missing for TID {row.tid} and PID {row.pid} "
            + f"on CPU {row.cpu}, the timestamp of the scheduling "
            + f"{'out' if sched_in_missing else 'in'} event: {row.start}ns."
        )

    def get_missing_sched_in_event_msg(row: pd.Series) -> str:
        return get_missing_event_msg(row, True)

    def get_missing_sched_out_event_msg(row: pd.Series) -> str:
        return get_missing_event_msg(row, False)

    def get_gtid_mismatch_msg(prev_row: pd.Series, row: pd.Series) -> str:
        return (
            f"Thread scheduling data is missing. On CPU {prev_row.cpu} "
            + f"at time {prev_row.start}ns, TID {prev_row.tid} was scheduled in "
            + f"but not scheduled out later. At time {row.start}ns, TID {row.tid} was "
            + f"scheduled out but never scheduled in."
        )

    def get_cpu_mismatch_msg(prev_row: pd.Series, row: pd.Series) -> str:
        return (
            f"The scheduling out event for TID {prev_row.tid} and PID {prev_row.pid} "
            + f"occurred on a different CPU than the scheduling in event. "
            + f"The thread was scheduled in on CPU {prev_row.cpu} and "
            + f"scheduled out on CPU {row.cpu}, "
            + f"scheduling in time: {prev_row.start}ns, scheduling out time: {row.start}ns."
        )

    def get_mismatch_msg(prev_row: pd.Series, row: pd.Series, key: str) -> str:
        if key == "globalTid":
            return get_gtid_mismatch_msg(prev_row, row)
        if key == "cpu":
            return get_cpu_mismatch_msg(prev_row, row)
        raise SchedEventError(f"Unknown key: {key}")

    def get_two_events_same_time_msg(row: pd.Series, key: str) -> str:
        start = "Two scheduling events "
        if key == "globalTid":
            start += f"for TID {row.tid} and PID {row.pid}"
        elif key == "cpu":
            start += f"on CPU {row.cpu}"
        else:
            raise SchedEventError(f"Unknown key: {key}")
        return start + f" happened at the same time: {row.start}ns."

    # We currently don't throw a SchedEventError when:
    #   1. The last event for the thread is a scheduling-in event, as it is
    #      possible that the thread was scheduled out after the profiling was
    #      stopped.
    #   2. The first event for the thread is a scheduling-out event, as it is
    #      possible that the thread was scheduled in before the profiling was
    #      started.
    # TODO: Update this when DTSP-19672 is implemented, which will handle the 1st case.

    def first_flagged_row_and_next(
        sched_group: pd.DataFrame, mask: pd.Series
    ) -> tuple[pd.Series, pd.Series]:
        # Return the first row of `sched_group` flagged by `mask` together
        # with the row that follows it *in the same group*. The index labels
        # of the time-sorted frame are not consecutive (sort_values keeps the
        # original labels), so the lookup must be positional. Every mask used
        # below compares a row with its successor (shift(-1)), so a flagged
        # row is never the last row of its group.
        pos = int(np.argmax(mask.to_numpy(dtype=bool, na_value=False)))
        return sched_group.iloc[pos], sched_group.iloc[pos + 1]

    def iterate_over_sched_events(
        sched_df: pd.DataFrame,
        group_key: str,
        comp_key: str,
        process_fn: Optional[Any] = None,
    ) -> None:
        # Validate the time-sorted events of every `group_key` group and, when
        # `process_fn` is given, call it once per group that has valid
        # (scheduling-in, scheduling-out) pairs, passing the scheduling-in rows
        # plus an `end` column holding the time of the matching
        # scheduling-out event.
        for _, sched_group in sched_df.groupby(group_key, sort=False):
            next_start = sched_group["start"].shift(-1)
            is_sched_in = sched_group["isSchedIn"]
            next_is_sched_in = is_sched_in.shift(-1)

            # Throw error for two events that happen at the same time
            same_time_mask = sched_group["start"] == next_start
            if same_time_mask.any():
                row, _ = first_flagged_row_and_next(sched_group, same_time_mask)
                raise SchedEventError(get_two_events_same_time_msg(row, group_key))

            same_comp_key_mask = (
                sched_group[comp_key] == sched_group[comp_key].shift(-1)
            )

            # Explicit `== True` / `== False` comparisons: the shifted column
            # holds NaN for the last event, which must compare as False.
            sched_in_out_mask = (
                (is_sched_in == True)  # noqa: E712
                & (next_is_sched_in == False)  # noqa: E712
            )

            # Throw error for two events that both are scheduling-in or scheduling-out
            same_sched_mask = is_sched_in == next_is_sched_in
            # NSys can generate two scheduling-out events in a row for the same
            # TID, PID and CPU when the target thread terminates.
            # These will be two scheduling-out events with different thread states:
            # the first with the TERMINATED thread state and the second with
            # the UNKNOWN thread state.
            # We allow such cases to pass without raising an error.
            allowed_same_sched_mask = (
                same_comp_key_mask
                & (is_sched_in == False)  # noqa: E712
                & (next_is_sched_in == False)  # noqa: E712
                & (sched_group["threadState"] == THREAD_STATE_TERMINATED)
                & (sched_group["threadState"].shift(-1) == THREAD_STATE_UNKNOWN)
            )
            same_sched_mask = same_sched_mask & ~allowed_same_sched_mask
            if same_sched_mask.any():
                row, next_row = first_flagged_row_and_next(sched_group, same_sched_mask)
                if row["isSchedIn"]:
                    raise SchedEventError(get_missing_sched_out_event_msg(row))
                raise SchedEventError(get_missing_sched_in_event_msg(next_row))

            # Throw error for mismatched comp_key (e.g., TID or CPU)
            comp_key_mismatch_mask = sched_in_out_mask & ~same_comp_key_mask
            if comp_key_mismatch_mask.any():
                row, next_row = first_flagged_row_and_next(
                    sched_group, comp_key_mismatch_mask
                )
                raise SchedEventError(get_mismatch_msg(row, next_row, comp_key))

            if process_fn is None:
                continue

            # Process valid pairs (sched-in, sched-out)
            sched_pair_mask = sched_in_out_mask & same_comp_key_mask
            if not sched_pair_mask.any():
                # Nothing to report for this group (e.g. it only holds a
                # single scheduling-out event); move on to the next group
                # instead of abandoning all the remaining ones.
                continue
            sched_pairs = sched_group[sched_pair_mask].assign(
                end=next_start[sched_pair_mask].to_numpy().astype("int64")
            )
            process_fn(sched_pairs)

    thread_sched_df = thread_sched_df.sort_values(by="start")

    # Iterate over the scheduling events per CPU (validation only)
    iterate_over_sched_events(thread_sched_df, "cpu", "globalTid")

    # Iterate over the scheduling events per global TID, collecting the
    # valid (scheduling-in, scheduling-out) pairs
    cpu_ranges: list[pd.DataFrame] = []
    iterate_over_sched_events(thread_sched_df, "globalTid", "cpu", cpu_ranges.append)

    cpu_activity_columns = ["cpu", "start", "end", "tid", "pid"]

    if not cpu_ranges:
        return pd.DataFrame(columns=cpu_activity_columns)
    return (
        pd.concat(cpu_ranges)
        .sort_values(by=["cpu", "start"])
        .reset_index(drop=True)[cpu_activity_columns]
    )


def compute_cpu_time(
    ranges_df: pd.DataFrame,
    cpu_df: pd.DataFrame,
    rely_on_tid: bool = True,
) -> pd.Series:
    """
    Calculate the active CPU time for the given ranges.

    Parameters
    ----------
    ranges_df : pd.DataFrame
        DataFrame containing ranges of interest with columns:
        start, end, globalTid, tid, pid.
    cpu_df : pd.DataFrame
        DataFrame containing CPU activity ranges with columns:
        cpu, start, end, tid, pid.
    rely_on_tid : bool = True
        If True, the function will only consider time on CPUs
        on which the thread of a particular range was running.

        If False, the function will only consider time on CPUs
        on which the process of a particular range was running (data from all
        threads active while the range was running will be included).

    Returns
    -------
    pd.Series
        Series indexed like `ranges_df` containing, for each range, the sum
        of the overlaps between the range and its matching CPU activity
        ranges. An empty Series is returned if either input is empty.
    """
    if ranges_df.empty or cpu_df.empty:
        return pd.Series(dtype=int)

    cpu_times = []
    # `rng` rather than `range` so that the builtin is not shadowed.
    for rng in cast(Iterator[Any], ranges_df.itertuples()):
        # 1. Attribute CPU activity ranges to the range of interest by time
        #    overlap, PID and TID (if rely_on_tid=True) match.
        range_cpu_mask = (
            (cpu_df["start"] <= rng.end)
            & (cpu_df["end"] >= rng.start)
            & (cpu_df["pid"] == rng.pid)
        )
        if rely_on_tid:
            range_cpu_mask &= cpu_df["tid"] == rng.tid
        range_cpu_df = cpu_df[range_cpu_mask]

        # 2. Calculate the active CPU time for the range: the sum, over the
        #    attributed CPU activity ranges, of their intersection with the
        #    range of interest.
        overlap_starts = range_cpu_df["start"].clip(lower=rng.start)
        overlap_ends = range_cpu_df["end"].clip(upper=rng.end)
        cpu_times.append((overlap_ends - overlap_starts).sum())

    return pd.Series(cpu_times, index=ranges_df.index)


def _get_perf_event_scaler_for_range(
    perf_event_start: int, perf_event_end: int, range_start: int, range_end: int
) -> float:
    """
    Calculate the scaler for a perf event based on the range of interest.
    Parameters
    ----------
    perf_event_start : int
        Start time of the perf event in ns.
    perf_event_end : int
        End time of the perf event in ns.
    range_start : int
        Start time of the range of interest in ns.
    range_end : int
        End time of the range of interest in ns.
    Returns
    -------
    float
        Scaler for the perf event based on the range of interest.
    """
    count_scaler = 0.0
    if perf_event_start >= range_start:
        if perf_event_end <= range_end:
            # |--------------------------------------------|
            #       |-----------Range-----------|
            #               |--Perf Event--|
            count_scaler = 1
        elif perf_event_start < range_end:
            # |--------------------------------------------|
            #       |-----------Range-----------|
            #                              |--Perf Event--|
            perf_event_time = perf_event_end - perf_event_start
            perf_event_in_range_time = range_end - perf_event_start
            count_scaler = perf_event_in_range_time / perf_event_time
    elif perf_event_end > range_start:
        perf_event_time = perf_event_end - perf_event_start
        if perf_event_end <= range_end:
            # |--------------------------------------------|
            #       |-----------Range-----------|
            # |--Perf Event--|
            perf_event_in_range_time = perf_event_end - range_start
            count_scaler = perf_event_in_range_time / perf_event_time
        else:
            # |--------------------------------------------|
            #                  |-Range-|
            #               |--Perf Event--|
            range_time = range_end - range_start
            count_scaler = range_time / perf_event_time

    return count_scaler


def _pair_perf_events_with_cpu_activity(
    range_perf_df: pd.DataFrame, range_cpu_df: pd.DataFrame
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Pair perf events with CPU activity ranges of the same CPU.

    For every CPU present in both inputs, `overlap.map_overlapping_ranges()`
    maps each perf event index to a list of CPU activity range indices. Each
    (perf event, activity range) entry of that mapping becomes one pair.

    Returns
    -------
    A tuple of two DataFrames of equal length whose i-th rows form the i-th
    pair. A perf event or activity range is repeated once per pair it is in.
    """
    range_perf_gdf = range_perf_df.groupby("cpu")

    perf_indices = []
    cpu_indices = []
    for cpu, curr_cpu_df in range_cpu_df.groupby("cpu"):
        if cpu not in range_perf_gdf.groups:
            continue

        perf_to_cpu_idx_map = overlap.map_overlapping_ranges(
            range_perf_gdf.get_group(cpu), curr_cpu_df, key_df="df1"
        )

        for perf_idx, mapped_cpu_indices in perf_to_cpu_idx_map.items():
            perf_indices.extend([perf_idx] * len(mapped_cpu_indices))
            cpu_indices.extend(mapped_cpu_indices)

    return range_perf_df.loc[perf_indices], range_cpu_df.loc[cpu_indices]


def compute_core_perf_events(
    ranges_df: pd.DataFrame,
    core_perf_df: pd.DataFrame,
    cpu_df: pd.DataFrame,
    rely_on_tid: bool = True,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compute core perf events and their number of samples for the given ranges.

    Parameters
    ----------
    ranges_df : pd.DataFrame
        DataFrame containing ranges of interest with columns:
        start, end, globalTid, tid, pid.
    core_perf_df : pd.DataFrame
        DataFrame containing core perf events collected by nsys.
        Columns: start, end, cpu, name, count.
    cpu_df : pd.DataFrame
        DataFrame containing CPU activity ranges with columns:
        cpu, start, end, tid, pid.
    rely_on_tid : bool = True
        If True, the function will only consider events that occurred on CPUs
        on which the thread of a particular range was running.

        If False, the function will only consider events that occurred on CPUs
        on which the process of a particular range was running (data from all
        threads active while the range was running will be included).

    Returns
    -------
    A tuple of two elements:
    - perf_events : pd.DataFrame
        DataFrame containing core perf events collected by nsys with a row
        corresponding to the range. Each value is the sum of the scaled
        counts of that event, rounded once to the nearest integer.
    - perf_samples : pd.DataFrame
        DataFrame containing the number of samples for each perf event
        for each CPU with a row corresponding to the range.
    Both DataFrames are empty if any of the inputs is empty.
    """
    if ranges_df.empty or core_perf_df.empty or cpu_df.empty:
        return pd.DataFrame(), pd.DataFrame()

    perf_events = []
    perf_samples = []
    perf_event_names = core_perf_df["name"].unique()

    # Unfortunately there is no simple way to iterate more type-safely.
    # `rng` rather than `range` so that the builtin is not shadowed.
    for rng in cast(Iterator[Any], ranges_df.itertuples()):
        # 1. Filter perf events for the range of interest by time overlap.
        range_perf_df = core_perf_df[
            (core_perf_df["start"] <= rng.end) & (core_perf_df["end"] >= rng.start)
        ]

        # 2. Attribute CPU activity ranges to the range of interest by time overlap,
        #    PID and TID (if rely_on_tid=True) match.
        range_cpu_mask = (
            (cpu_df["start"] <= rng.end)
            & (cpu_df["end"] >= rng.start)
            & (cpu_df["pid"] == rng.pid)
        )
        if rely_on_tid:
            range_cpu_mask &= cpu_df["tid"] == rng.tid
        range_cpu_df = cpu_df[range_cpu_mask]

        # 3. Filter perf events by CPU activity ranges using map_overlapping_ranges().
        range_perf_df, range_cpu_df = _pair_perf_events_with_cpu_activity(
            range_perf_df, range_cpu_df
        )

        # Values are accumulated as floats and rounded once per event name
        # below.
        range_perf_values = {x: 0.0 for x in perf_event_names}
        range_perf_samples = {
            f"{x}_{y}": 0.0
            for x in perf_event_names
            for y in range_cpu_df["cpu"].values
        }

        # Here the `_get_perf_event_scaler_for_range()` is called for each
        # portion of the range that happens on a particular CPU. E.g.,
        # for each case from a to f in the example below:
        #        |----------------------Timeline------------------------|
        #          |------------------Range (TID 1)------------------|
        #
        #             a             b                     c
        # CPU 1:   |-TID1-|      |-TID1-|              |-TID1-|
        # Perf:  |----PE----|----PE----|----PE----|----PE----|----PE----|
        #              1          2          3          4          5
        #
        #                    d                e                  f
        # CPU 2:          |-TID1-|      |----TID1------|      |-TID1-|
        # Perf:  |----PE----|----PE----|----PE----|----PE----|----PE----|
        #              1          2          3          4          5
        # Case a:
        # - For perf event 1 on CPU 1:
        #    count_scaler_a_1 = time(a) / time(PE 1).
        # Case b:
        # - For perf event 2 on CPU 1:
        #   count_scaler_b_2 = (end(PE 2) - start(b)) / time(PE 2)
        # - For perf event 3 on CPU 1:
        #   count_scaler_b_3 = (end(b) - start(PE 3)) / time(PE 3)
        # Case c:
        # - For perf event 4 on CPU 1:
        #   count_scaler_c_4 = (end(PE 4) - start(c)) / time(PE 4)
        # - For perf event 5 on CPU 1:
        #   count_scaler_c_5 = (end(c) - start(PE 5)) / time(PE 5)
        # Case d:
        # - For perf event 1 on CPU 2:
        #   count_scaler_d_1 = (end(PE 1) - start(d)) / time(PE 1)
        # - For perf event 2 on CPU 2:
        #   count_scaler_d_2 = (end(d) - start(PE 2)) / time(PE 2)
        # Case e:
        # - For perf event 3 on CPU 2:
        #   count_scaler_e_3 = (end(PE 3) - start(e)) / time(PE 3)
        # - For perf event 4 on CPU 2:
        #   count_scaler_e_4 = (end(e) - start(PE 4)) / time(PE 4)
        # Case f:
        # - For perf event 5 on CPU 2:
        #   count_scaler_f_5 = time(f) / time(PE 5)
        #
        # The total perf event value for the whole range
        # will be calculated as follows (then rounded to the nearest integer):
        # total value of perf event =
        #     count_scaler_a_1 * value(PE 1 on CPU 1) +
        #     count_scaler_b_2 * value(PE 2 on CPU 1) +
        #     count_scaler_b_3 * value(PE 3 on CPU 1) +
        #     count_scaler_c_4 * value(PE 4 on CPU 1) +
        #     count_scaler_c_5 * value(PE 5 on CPU 1) +
        #     count_scaler_d_1 * value(PE 1 on CPU 2) +
        #     count_scaler_d_2 * value(PE 2 on CPU 2) +
        #     count_scaler_e_3 * value(PE 3 on CPU 2) +
        #     count_scaler_e_4 * value(PE 4 on CPU 2) +
        #     count_scaler_f_5 * value(PE 5 on CPU 2)
        #
        # The total number of samples for the whole range
        # will be calculated per each CPU as follows:
        # total number of samples of perf event for CPU 1 =
        #     count_scaler_a_1 + count_scaler_b_2 + count_scaler_b_3 +
        #     count_scaler_c_4 + count_scaler_c_5
        # total number of samples of perf event for CPU 2 =
        #     count_scaler_d_1 + count_scaler_d_2 + count_scaler_e_3 +
        #     count_scaler_e_4 + count_scaler_f_5

        for perf_event, cpu_activity_range in zip(
            range_perf_df.itertuples(), range_cpu_df.itertuples()
        ):
            # Calculate the start and end times of the range portion
            # that does the actual work on the CPU.
            range_start = max(rng.start, cpu_activity_range.start)
            range_end = min(rng.end, cpu_activity_range.end)

            count_scaler = _get_perf_event_scaler_for_range(
                perf_event.start, perf_event.end, range_start, range_end
            )

            range_perf_values[perf_event.name] += count_scaler * perf_event.count
            key = f"{perf_event.name}_{cpu_activity_range.cpu}"
            range_perf_samples[key] += count_scaler

        # Round the totals rather than each contribution. Per-term rounding
        # biases the result: an event with count 3 split into two halves
        # would give round(1.5) + round(1.5) == 4 instead of 3.
        range_perf_events = {
            name: round(value) for name, value in range_perf_values.items()
        }

        range_perf_events_df = pd.DataFrame(range_perf_events, index=[rng.Index])
        range_perf_samples_df = pd.DataFrame(range_perf_samples, index=[rng.Index])

        perf_events.append(range_perf_events_df)
        perf_samples.append(range_perf_samples_df)

    return pd.concat(perf_events), pd.concat(perf_samples)


class Equation:
    """
    Arithmetic expression over named operands, evaluated on DataFrame columns.

    The expression is parsed and compiled once at construction. Every bare
    name in it is an operand, looked up as a column of the DataFrame passed
    to `run()`.
    """

    class DataExtractor(ast.NodeVisitor):
        """Collect the identifiers of all `ast.Name` nodes, in visit order."""

        def __init__(self):
            super().__init__()
            self._operands = []

        def visit_Name(self, node: ast.Name) -> ast.Name:
            self._operands.append(node.id)
            return node

        @property
        def operands(self) -> list[str]:
            return self._operands

    def __init__(self, equation_str: str):
        self._ast = ast.parse(equation_str, mode="eval")
        # Compile once instead of on every `run()` call. The same Equation
        # object may be run many times (e.g. the cached Arm metric equations).
        self._code = compile(self._ast, filename="", mode="eval")
        data_extractor = Equation.DataExtractor()
        data_extractor.visit(self._ast)
        self._operands = data_extractor.operands

    def run(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """
        Evaluate the equation using the columns of `df` as operand values.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame whose columns named after the operands are used.

        Returns
        -------
        Optional[np.ndarray]
            The result of the expression. Returns None if an operand column
            is missing from `df`. Returns np.inf if evaluation raises
            ZeroDivisionError, which only happens with plain Python scalars:
            NumPy array division by zero yields inf/nan instead of raising.
        """
        context = {}
        for operand in self._operands:
            if operand not in df.columns:
                return None
            context[operand] = df[operand].values

        # NOTE: `{"__builtins__": None}` is not a sandbox. This is only safe
        # because equations come from trusted, built-in metric definitions.
        try:
            return eval(self._code, {"__builtins__": None}, context)
        except ZeroDivisionError:
            return np.inf


def _parse_perf_metric_equations(
    metric_infos: list[Optional[am.PerfMetric]],
) -> list[Optional[Equation]]:
    """
    Parse the equation of every metric description.

    Parameters
    ----------
    metric_infos : list[Optional[am.PerfMetric]]
        Metric descriptions. None entries are allowed and are kept as
        placeholders, so positions (metric indices) are preserved.

    Returns
    -------
    list[Optional[Equation]]
        A list of the same length as `metric_infos`. Element `i` is the parsed
        equation of `metric_infos[i]`, or None when that entry is None.
    """
    return [
        None if info is None else Equation(info.equation) for info in metric_infos
    ]


# Lazily populated cache of the parsed Arm SBSA metric equations; stays None
# until `_get_metric_equations()` is first called for that architecture.
_arm_metric_equations: Optional[list[Optional[Equation]]] = None


def _get_metric_equations(cpu_arch: str) -> Optional[list[Optional[Equation]]]:
    """
    Return the parsed metric equations for `cpu_arch`.

    Only Architecture.AARCH64_SBSA has metric equations here. They are parsed
    on first use and cached in `_arm_metric_equations`. Every other
    architecture gets None.
    """
    global _arm_metric_equations
    if cpu_arch != Architecture.AARCH64_SBSA:
        return None
    if _arm_metric_equations is None:
        _arm_metric_equations = _parse_perf_metric_equations(am.get_arm_metrics())
    return _arm_metric_equations


def compute_perf_metrics(
    ranges_df: pd.DataFrame, time_column: str, cpu_arch: str
) -> pd.DataFrame:
    """
    Compute performance metrics for the provided ranges.

    Parameters
    ----------
    ranges_df : pd.DataFrame
        DataFrame containing ranges of interest with core perf events.
        It is not modified.
    time_column : str
        Column name in `ranges_df` that contains the time values. The metric
        equations see this column under the name TIME.
    cpu_arch : str
        CPU architecture type, e.g., Architecture.AARCH64_SBSA.

    Returns
    -------
    pd.DataFrame
        DataFrame indexed like `ranges_df` with one column per metric that
        could be evaluated, named after its `am.PerfMetricType` member.
        Metrics whose operands are not all present are omitted. The frame has
        no columns when `cpu_arch` has no metric equations.
    """
    # `assign` returns a new frame, so adding the TIME operand column leaves
    # the caller's DataFrame untouched.
    operands_df = ranges_df.assign(TIME=ranges_df[time_column])
    metrics_df = pd.DataFrame(index=ranges_df.index)

    equations = _get_metric_equations(cpu_arch)
    if equations is None:
        return metrics_df

    for metric_idx, equation in enumerate(equations):
        if equation is None:
            continue
        metric_name = am.PerfMetricType(metric_idx).name
        values = equation.run(operands_df)
        if values is not None:
            metrics_df[metric_name] = values
    return metrics_df
