# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from libc.stdint cimport uintptr_t, int64_t, uint64_t

from cuda.bindings cimport cydriver
from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource
from cuda.core.experimental._memory._ipc cimport IPCBufferDescriptor, IPCDataForBuffer
from cuda.core.experimental._memory cimport _ipc
from cuda.core.experimental._stream cimport Stream_accept, Stream
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN

import abc
from typing import TypeVar, Union

from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule
from cuda.core.experimental._utils.cuda_utils import driver

# Names exported by ``from <this module> import *``.
__all__ = ['Buffer', 'MemoryResource']


# Accepted representations of a buffer handle: a driver-level CUdeviceptr
# object, a plain integer address, or None.
DevicePointerT = Union[driver.CUdeviceptr, int, None]
"""
A type union of :obj:`~driver.CUdeviceptr`, `int` and `None` for hinting
:attr:`Buffer.handle`.
"""

cdef class Buffer:
    """Represent a handle to allocated memory.

    This generic object provides a unified representation of how
    different memory resources give access to their memory
    allocations.

    Support for data interchange is provided through DLPack.

    Instances cannot be created by calling ``Buffer(...)`` directly (doing so
    raises :class:`RuntimeError`); use :meth:`from_handle`,
    :meth:`from_ipc_descriptor` or a memory resource instead.
    """
    def __cinit__(self):
        # Start from the empty state so that every field has a defined value
        # even if _init() never runs for this instance.
        self._clear()

    def _clear(self):
        """Reset every field to its empty value (null pointer, size 0, no owners)."""
        self._ptr = 0
        self._size = 0
        self._memory_resource = None
        self._ipc_data = None
        self._ptr_obj = None
        self._alloc_stream = None

    def __init__(self, *args, **kwargs):
        raise RuntimeError("Buffer objects cannot be instantiated directly. "
                           "Please use MemoryResource APIs.")

    @classmethod
    def _init(
        cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None,
        stream: Stream | None = None, ipc_descriptor: IPCBufferDescriptor | None = None
    ):
        """Internal constructor that bypasses the guarded ``__init__``.

        Parameters
        ----------
        ptr : :obj:`~_memory.DevicePointerT`
            Buffer handle. It must be convertible with ``int()``; the integer
            address and the original object are both stored.
        size : int
            Memory size of the buffer
        mr : :obj:`~_memory.MemoryResource`, optional
            Memory resource that :meth:`close` hands the buffer back to
        stream : :obj:`~_stream.Stream`, optional
            Stream remembered as the default for deallocation
        ipc_descriptor : IPCBufferDescriptor, optional
            Descriptor to attach to the buffer when it is given
        """
        cdef Buffer self = Buffer.__new__(cls)
        self._ptr = <uintptr_t>(int(ptr))
        self._ptr_obj = ptr
        self._size = size
        self._memory_resource = mr
        # NOTE(review): the second argument presumably flags the buffer as
        # mapped/imported (see ``is_mapped``) -- confirm in IPCDataForBuffer.
        self._ipc_data = IPCDataForBuffer(ipc_descriptor, True) if ipc_descriptor is not None else None
        self._alloc_stream = <Stream>(stream) if stream is not None else None
        return self

    def __dealloc__(self):
        # Release the allocation on the stream remembered at creation time.
        # NOTE(review): Cython's documentation advises against calling methods
        # on ``self`` from __dealloc__; confirm this is safe for subclasses.
        self.close(self._alloc_stream)

    def __reduce__(self):
        """Pickle support: rebuild through :meth:`from_ipc_descriptor`."""
        # Must not serialize the parent's stream!
        return Buffer.from_ipc_descriptor, (self.memory_resource, self.get_ipc_descriptor())

    @staticmethod
    def from_handle(
        ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None
    ) -> Buffer:
        """Create a new :class:`Buffer` object from a pointer.

        Parameters
        ----------
        ptr : :obj:`~_memory.DevicePointerT`
            Allocated buffer handle object
        size : int
            Memory size of the buffer
        mr : :obj:`~_memory.MemoryResource`, optional
            Memory resource associated with the buffer

        Returns
        -------
        Buffer
            A buffer wrapping ``ptr``. No allocation stream is recorded.
        """
        # TODO: It is better to take a stream for later deallocation
        return Buffer._init(ptr, size, mr=mr)

    @classmethod
    def from_ipc_descriptor(
        cls, mr: DeviceMemoryResource, ipc_descriptor: IPCBufferDescriptor,
        stream: Stream = None
    ) -> Buffer:
        """Import a buffer that was exported from another process."""
        return _ipc.Buffer_from_ipc_descriptor(cls, mr, ipc_descriptor, stream)

    def get_ipc_descriptor(self) -> IPCBufferDescriptor:
        """Export a buffer allocated for sharing between processes.

        The descriptor is created on first use and cached on the buffer.
        """
        if self._ipc_data is None:
            self._ipc_data = IPCDataForBuffer(_ipc.Buffer_get_ipc_descriptor(self), False)
        return self._ipc_data.ipc_descriptor

    def close(self, stream: Stream | GraphBuilder | None = None):
        """Deallocate this buffer asynchronously on the given stream.

        This buffer is released back to its memory resource
        asynchronously on the given stream. Nothing happens if the buffer
        holds a null pointer or has no associated memory resource.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`, optional
            The stream object to use for asynchronous deallocation. If None,
            the stream recorded when the buffer was created is passed to the
            memory resource; if that is None as well, the behavior depends on
            the underlying memory resource.
        """
        Buffer_close(self, stream)

    def copy_to(self, dst: Buffer = None, *, stream: Stream | GraphBuilder) -> Buffer:
        """Copy from this buffer to the dst buffer asynchronously on the given stream.

        Copies the data from this buffer to the provided dst buffer.
        If the dst buffer is not provided, then a new buffer is first
        allocated using the associated memory resource before the copy.

        Parameters
        ----------
        dst : :obj:`~_memory.Buffer`, optional
            Destination buffer to copy data to
        stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
            Keyword argument specifying the stream for the
            asynchronous copy

        Returns
        -------
        Buffer
            The destination buffer (``dst``, or the newly allocated buffer).

        Raises
        ------
        ValueError
            If ``dst`` is omitted and this buffer has no memory resource, or
            if the sizes of the two buffers differ.

        """
        stream = Stream_accept(stream)
        cdef Stream s_stream = <Stream>stream
        cdef size_t src_size = self._size

        if dst is None:
            if self._memory_resource is None:
                raise ValueError("a destination buffer must be provided (this "
                                 "buffer does not have a memory_resource)")
            dst = self._memory_resource.allocate(src_size, stream)

        cdef size_t dst_size = dst._size
        if dst_size != src_size:
            raise ValueError( "buffer sizes mismatch between src and dst (sizes "
                             f"are: src={src_size}, dst={dst_size})"
            )
        # Read the raw stream handle before dropping the GIL.
        cdef cydriver.CUstream s = s_stream._handle
        with nogil:
            HANDLE_RETURN(cydriver.cuMemcpyAsync(
                <cydriver.CUdeviceptr>dst._ptr,
                <cydriver.CUdeviceptr>self._ptr,
                src_size,
                s
            ))
        return dst

    def copy_from(self, src: Buffer, *, stream: Stream | GraphBuilder):
        """Copy from the src buffer to this buffer asynchronously on the given stream.

        Parameters
        ----------
        src : :obj:`~_memory.Buffer`
            Source buffer to copy data from
        stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
            Keyword argument specifying the stream for the
            asynchronous copy

        Raises
        ------
        ValueError
            If the sizes of the two buffers differ.

        """
        stream = Stream_accept(stream)
        cdef Stream s_stream = <Stream>stream
        cdef size_t dst_size = self._size
        cdef size_t src_size = src._size

        if src_size != dst_size:
            raise ValueError( "buffer sizes mismatch between src and dst (sizes "
                             f"are: src={src_size}, dst={dst_size})"
            )
        # Read the raw stream handle before dropping the GIL.
        cdef cydriver.CUstream s = s_stream._handle
        with nogil:
            HANDLE_RETURN(cydriver.cuMemcpyAsync(
                <cydriver.CUdeviceptr>self._ptr,
                <cydriver.CUdeviceptr>src._ptr,
                dst_size,
                s
            ))

    def fill(self, value: int, width: int, *, stream: Stream | GraphBuilder):
        """Fill this buffer with a value pattern asynchronously on the given stream.

        Parameters
        ----------
        value : int
            Integer value to fill the buffer with
        width : int
            Width in bytes for each element (must be 1, 2, or 4)
        stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
            Keyword argument specifying the stream for the asynchronous fill

        Raises
        ------
        ValueError
            If width is not 1, 2, or 4, if value is out of range for the width,
            or if buffer size is not divisible by width

        """
        cdef Stream s_stream = Stream_accept(stream)
        cdef unsigned char c_value8
        cdef unsigned short c_value16
        cdef unsigned int c_value32
        cdef size_t N

        # Validate width
        if width not in (1, 2, 4):
            raise ValueError(f"width must be 1, 2, or 4, got {width}")

        # Validate buffer size modulus.
        cdef size_t buffer_size = self._size
        if buffer_size % width != 0:
            raise ValueError(f"buffer size ({buffer_size}) must be divisible by width ({width})")

        # Map width (bytes) to bitwidth and validate value as an unsigned
        # integer of that width.
        cdef int bitwidth = width * 8
        try:
            _validate_value_against_bitwidth(bitwidth, value, is_signed=False)
        except OverflowError:
            # ``value`` does not even fit in the helper's int64_t parameter, so
            # the coercion failed before any range check could run. Report it
            # as the documented ValueError. The bound is computed with Python
            # integers (``width`` is a Python object) to avoid C overflow.
            raise ValueError(
                f"value must be in range [0, {(1 << (width * 8)) - 1}]"
            ) from None

        # Dispatch to the memset variant matching the element width; N is the
        # number of elements (not bytes) to set.
        cdef cydriver.CUstream s = s_stream._handle
        if width == 1:
            c_value8 = <unsigned char>value
            N = buffer_size
            with nogil:
                HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s))
        elif width == 2:
            c_value16 = <unsigned short>value
            N = buffer_size // 2
            with nogil:
                HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s))
        else:  # width == 4
            c_value32 = <unsigned int>value
            N = buffer_size // 4
            with nogil:
                HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s))

    def __dlpack__(
        self,
        *,
        stream: int | None = None,
        max_version: tuple[int, int] | None = None,
        dl_device: tuple[int, int] | None = None,
        copy: bool | None = None,
    ) -> TypeVar("PyCapsule"):
        """Export this buffer as a DLPack capsule.

        ``dl_device`` values other than None and ``copy=True`` are rejected
        with :class:`BufferError`. A versioned capsule is requested only when
        ``max_version`` is a 2-tuple that compares ``>= (1, 0)``.
        """
        # Note: we ignore the stream argument entirely (as if it is -1).
        # It is the user's responsibility to maintain stream order.
        if dl_device is not None:
            raise BufferError("Sorry, not supported: dl_device other than None")
        if copy is True:
            raise BufferError("Sorry, not supported: copy=True")
        if max_version is None:
            versioned = False
        else:
            if not isinstance(max_version, tuple) or len(max_version) != 2:
                raise BufferError(f"Expected max_version tuple[int, int], got {max_version}")
            versioned = max_version >= (1, 0)
        capsule = make_py_capsule(self, versioned)
        return capsule

    def __dlpack_device__(self) -> tuple[int, int]:
        """Return the DLPack ``(device_type, device_id)`` pair for this buffer.

        The device type is derived from the accessibility flags: device-only
        memory maps to ``kDLCUDA``, memory accessible from both sides maps to
        ``kDLCUDAHost`` and host-only memory maps to ``kDLCPU``.
        """
        cdef bint d = self.is_device_accessible
        cdef bint h = self.is_host_accessible
        if d and (not h):
            return (DLDeviceType.kDLCUDA, self.device_id)
        if d and h:
            # TODO: this can also be kDLCUDAManaged, we need more fine-grained checks
            return (DLDeviceType.kDLCUDAHost, 0)
        if (not d) and h:
            return (DLDeviceType.kDLCPU, 0)
        raise BufferError("buffer is neither device-accessible nor host-accessible")

    def __buffer__(self, flags: int, /) -> memoryview:
        # Support for Python-level buffer protocol as per PEP 688.
        # This raises a BufferError unless:
        #   1. Python is 3.12+
        #   2. This Buffer object is host accessible
        raise NotImplementedError("WIP: Buffer.__buffer__ hasn't been implemented yet.")

    def __release_buffer__(self, buffer: memoryview, /):
        # Supporting method paired with __buffer__.
        raise NotImplementedError("WIP: Buffer.__release_buffer__ hasn't been implemented yet.")

    @property
    def device_id(self) -> int:
        """Return the device ordinal of this buffer."""
        if self._memory_resource is not None:
            return self._memory_resource.device_id
        raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource")

    @property
    def handle(self) -> DevicePointerT:
        """Return the buffer handle object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Buffer.handle)``.
        """
        if self._ptr_obj is not None:
            return self._ptr_obj
        elif self._ptr:
            return self._ptr
        else:
            # contract: Buffer is closed
            return 0

    @property
    def is_device_accessible(self) -> bool:
        """Return True if this buffer can be accessed by the GPU, otherwise False."""
        if self._memory_resource is not None:
            return self._memory_resource.is_device_accessible
        raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource")

    @property
    def is_host_accessible(self) -> bool:
        """Return True if this buffer can be accessed by the CPU, otherwise False."""
        if self._memory_resource is not None:
            return self._memory_resource.is_host_accessible
        raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource")

    @property
    def is_mapped(self) -> bool:
        """Return True if this buffer is mapped into the process via IPC."""
        # Falls back to False when no IPC data is attached (``_ipc_data`` is None).
        return getattr(self._ipc_data, "is_mapped", False)

    @property
    def memory_resource(self) -> MemoryResource:
        """Return the memory resource associated with this buffer."""
        return self._memory_resource

    @property
    def size(self) -> int:
        """Return the memory size of this buffer."""
        return self._size


# Buffer Implementation
# ---------------------
cdef inline void Buffer_close(Buffer self, stream):
    """Hand the allocation held by ``self`` back to its memory resource.

    Does nothing when the buffer holds a null pointer or has no memory
    resource. Otherwise the resource's ``deallocate`` is called with the
    given stream (or, when ``stream`` is None, the stream stored on the
    buffer), after which the pointer, resource, handle object and stored
    stream are cleared.
    """
    cdef Stream dealloc_stream

    # Guard: nothing to release.
    if not self._ptr or self._memory_resource is None:
        return

    # An explicitly supplied stream takes precedence over the stored one.
    if stream is None:
        dealloc_stream = self._alloc_stream
    else:
        dealloc_stream = Stream_accept(stream)

    self._memory_resource.deallocate(self._ptr, self._size, dealloc_stream)

    # Drop every reference so that a second close is a no-op.
    self._ptr = 0
    self._memory_resource = None
    self._ptr_obj = None
    self._alloc_stream = None


cdef class MemoryResource:
    """Abstract base class for objects that allocate and deallocate buffers.

    Concrete subclasses provide :meth:`allocate` and :meth:`deallocate`, along
    with the properties that describe the memory they manage. Every
    :class:`Buffer` returned by :meth:`allocate` is expected to keep a
    reference to the resource that produced it, so a buffer answers property
    queries by looking up the matching property on its memory resource.
    """

    @abc.abstractmethod
    def allocate(self, size_t size, stream: Stream | GraphBuilder | None = None) -> Buffer:
        """Allocate a buffer of ``size`` bytes.

        Parameters
        ----------
        size : int
            Number of bytes to allocate.
        stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`, optional
            Stream on which the allocation is performed asynchronously.
            When None, each memory resource implementation decides, and
            should document, what happens.

        Returns
        -------
        Buffer
            The newly allocated buffer; whether it is usable from the device,
            the host or both depends on the properties of the resource.
        """
        ...

    @abc.abstractmethod
    def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream | GraphBuilder | None = None):
        """Release a buffer that this resource allocated earlier.

        Parameters
        ----------
        ptr : :obj:`~_memory.DevicePointerT`
            Pointer or handle of the buffer being released.
        size : int
            Size of the buffer being released, in bytes.
        stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`, optional
            Stream on which the deallocation is performed asynchronously.
            When None, each memory resource implementation decides, and
            should document, what happens.
        """
        ...


# Helper Functions
# ----------------
cdef void _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except *:
    """Validate that a value fits within the representable range for a given bitwidth.

    Parameters
    ----------
    bitwidth : int
        Number of bits (e.g., 8, 16, 32); must lie in [1, 63]
    value : int64_t
        Value to validate
    is_signed : bool, optional
        Whether the value is interpreted as a signed (two's complement)
        integer of ``bitwidth`` bits (default: False)

    Raises
    ------
    ValueError
        If bitwidth is outside [1, 63], or if value is outside the
        representable range for the bitwidth
    """
    cdef int64_t min_value
    cdef int64_t max_value

    # Checked explicitly rather than with ``assert``: assertions can be
    # disabled, and the shifts below are undefined in C for a negative shift
    # count or one that reaches the operand width.
    if bitwidth < 1 or bitwidth > 63:
        raise ValueError(f"bitwidth ({bitwidth}) must be in the range [1, 63]")

    if is_signed:
        # Range is [-2**(bitwidth-1), 2**(bitwidth-1) - 1].
        min_value = -(<int64_t>1 << (bitwidth - 1))
        max_value = (<int64_t>1 << (bitwidth - 1)) - 1
    else:
        # Range is [0, 2**bitwidth - 1]; computed unsigned, then narrowed
        # (always fits because bitwidth <= 63).
        min_value = 0
        max_value = <int64_t>((<uint64_t>1 << bitwidth) - 1)

    if value < min_value or value > max_value:
        raise ValueError(
            f"value must be in range [{min_value}, {max_value}]"
        )
