# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import math
import numbers
from typing import Any, Optional, Union, overload

import numpy as np
import pandas as pd


def get_bin_size(bin_number: int, max_duration: int) -> int:
    """Return the bin width needed to span ``max_duration`` with ``bin_number`` bins.

    The width is rounded up, so ``bin_number`` bins of the returned size cover
    at least ``max_duration``.

    Parameters
    ----------
    bin_number : int
        Number of bins.
    max_duration : int
        Total length to be covered by the bins.

    Returns
    -------
    int
        Ceiling of ``max_duration / bin_number``.

    Raises
    ------
    ZeroDivisionError
        If ``bin_number`` is zero.
    """
    if isinstance(bin_number, numbers.Integral) and isinstance(
        max_duration, numbers.Integral
    ):
        # Exact integer ceiling division: a float quotient loses precision for
        # values above 2**53 (e.g. long durations measured in nanoseconds).
        return -(-int(max_duration) // int(bin_number))
    return math.ceil(max_duration / bin_number)


def generate_bin_list(
    bin_number: int, bin_size: int, include_last: bool = False
) -> list[int]:
    """Return the boundary values of consecutive, equally sized bins.

    Parameters
    ----------
    bin_number : int
        Number of bins described by the boundaries.
    bin_size : int
        Width of every bin; boundary ``i`` equals ``bin_size * i``.
    include_last : bool
        When true, also emit the closing boundary after the last bin, so
        ``bin_number + 1`` values are returned instead of ``bin_number``.
    """
    n_boundaries = bin_number + (1 if include_last else 0)
    boundaries = []
    for index in range(n_boundaries):
        boundaries.append(bin_size * index)
    return boundaries


@overload
def _calculate_bin_info(
    starts: int, ends: int, bin_size: int
) -> tuple[int, int, float, float]: ...


@overload
def _calculate_bin_info(
    starts: np.ndarray, ends: np.ndarray, bin_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...


@overload
def _calculate_bin_info(
    starts: pd.Series, ends: pd.Series, bin_size: int
) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]: ...


def _calculate_bin_info(
    starts: Union[int, np.ndarray, pd.Series],
    ends: Union[int, np.ndarray, pd.Series],
    bin_size: int,
):
    # Scale the start and end times by the bin size.
    start_scaled = starts / bin_size
    end_scaled = ends / bin_size

    # Calculate the bin index for each start and end.
    start_bins = np.floor(start_scaled).astype(int)
    end_bins = np.floor(end_scaled).astype(int)

    # Calculate the clipped start and end values to ensure they don't exceed
    # the bin boundaries.
    ends = np.minimum(np.ceil(start_scaled), end_scaled)
    starts = np.maximum(np.floor(end_scaled), start_scaled)

    # Calculate the coverage percentage.
    start_percents = ends - start_scaled
    end_percents = end_scaled - starts

    return start_bins, end_bins, start_percents, end_percents


def _rectify_pct_inplace(
    bin_pcts: np.ndarray, bin_size: int, max_duration: int, session_offset: int
) -> None:
    """Restrict ``bin_pcts`` to the profiling session window, in place.

    The session is treated as ``[session_offset, max_duration)`` in the same
    time base as the bins. The bins holding the session's first and last
    instants are divided by the fraction of the bin inside the session.
    Bins wholly outside the session are set to NaN. ``bin_pcts`` must
    therefore be a float array.

    Parameters
    ----------
    bin_pcts : np.ndarray
        Per-bin values to rectify; modified in place.
    bin_size : int
        Width of each bin.
    max_duration : int
        End of the session window.
    session_offset : int
        Start of the session window.
    """
    start_bin, end_bin, start_percent, end_percent = _calculate_bin_info(
        session_offset, max_duration, bin_size
    )
    num_bins = len(bin_pcts)

    # Normalize the partially covered edge bins. A zero fraction is skipped to
    # avoid dividing by zero, and indices outside the array are ignored.
    if start_percent != 0 and 0 <= start_bin < num_bins:
        bin_pcts[start_bin] /= start_percent
    if end_percent != 0 and start_bin != end_bin and 0 <= end_bin < num_bins:
        bin_pcts[end_bin] /= end_percent

    # Set values outside the profiling session to NaN. When the session ends
    # exactly on a bin boundary (end_percent == 0), the bin starting there lies
    # entirely outside the session as well.
    first_outside = end_bin + 1 if end_percent != 0 else end_bin
    bin_pcts[max(first_outside, 0) :] = np.nan
    # Clamp so a negative start bin cannot wrap around to the end of the array.
    bin_pcts[: max(start_bin, 0)] = np.nan


def get_zero_bin_pcts(
    bin_size: int, bin_num: int, max_duration: int, session_offset: int
) -> np.ndarray:
    """Return ``bin_num`` zero-valued bins restricted to the session window.

    Bins that ``_rectify_pct_inplace`` marks as outside the session become
    NaN; the remaining bins stay at zero.
    """
    empty_bins = np.zeros(bin_num)
    _rectify_pct_inplace(empty_bins, bin_size, max_duration, session_offset)
    return empty_bins


def calculate_bin_pcts(
    df: pd.DataFrame,
    bin_size: int,
    bin_num: int,
    max_duration: int,
    session_offset: int,
    value_key: Optional[str] = None,
) -> np.ndarray:
    """Calculate the percentage of each bin covered by the ranges in ``df``.

    Every row is a range ``[start, end)``. A bin receives the fraction of its
    width covered by the range, weighted by the row's value. The result is
    then passed through ``_rectify_pct_inplace`` for the session window
    ``[session_offset, max_duration)``.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain numeric ``start`` and ``end`` columns. Its index is
        irrelevant; rows are processed positionally.
    bin_size : int
        Width of each bin.
    bin_num : int
        Number of bins in the result.
    max_duration : int
        End of the session window.
    session_offset : int
        Start of the session window.
    value_key : str, optional
        Column holding each row's weight. When not given, every row has
        weight 1.

    Returns
    -------
    np.ndarray
        ``bin_num`` values in percent, rounded to one decimal. Contributions
        that fall on bin indices outside ``[0, bin_num)`` are dropped.
    """
    # Work on positional NumPy arrays. Mixing Series with different indexes
    # (e.g. a filtered ``df`` and a freshly built weight Series) would make
    # pandas align by label and produce NaN weights of the wrong length.
    starts = df["start"].to_numpy(dtype=float)
    ends = df["end"].to_numpy(dtype=float)
    if value_key:
        values = df[value_key].to_numpy(dtype=float)
    else:
        values = np.ones(len(starts))

    start_bins, end_bins, start_percents, end_percents = _calculate_bin_info(
        starts, ends, bin_size
    )

    # Ranges that fall in a single bin: either the start or the end fraction
    # can be used. Out-of-range bins are excluded so ``bincount`` neither
    # rejects negative indices nor grows the result past ``bin_num``.
    single = start_bins == end_bins
    in_range = (start_bins >= 0) & (start_bins < bin_num)
    selected = single & in_range
    bin_pcts = np.bincount(
        start_bins[selected],
        weights=start_percents[selected] * values[selected],
        minlength=bin_num,
    ).astype(float)

    # Ranges spanning several bins: partial first/last bins plus fully
    # covered bins in between.
    for idx in np.flatnonzero(~single):
        first, last, value = start_bins[idx], end_bins[idx], values[idx]
        if 0 <= first < bin_num:
            bin_pcts[first] += start_percents[idx] * value
        # ``last`` equals ``bin_num`` when a range ends exactly on the final
        # boundary; its coverage is 0 there, but indexing would raise.
        if 0 <= last < bin_num:
            bin_pcts[last] += end_percents[idx] * value
        lo, hi = max(first + 1, 0), min(last, bin_num)
        if lo < hi:
            bin_pcts[lo:hi] += value

    _rectify_pct_inplace(bin_pcts, bin_size, max_duration, session_offset)

    return (bin_pcts * 100).round(1)
