diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h new file mode 100644 index 0000000000..26ebd1164c --- /dev/null +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -0,0 +1,83 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI. + * Original: torch/csrc/inductor/aoti_torch/c/shim.h + * + * These are declarations only -- no definitions are provided. The actual + * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL) + * and resolved at runtime by the dynamic linker. This means PyTorch is + * NOT required at compile time. + */ + +#ifndef CUDA_CORE_AOTI_SHIM_H +#define CUDA_CORE_AOTI_SHIM_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef int32_t AOTITorchError; + +/* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */ +struct AtenTensorOpaque; +typedef struct AtenTensorOpaque* AtenTensorHandle; + +/* ---- tensor metadata --------------------------------------------------- */ + +AOTITorchError aoti_torch_get_data_ptr( + AtenTensorHandle tensor, void** ret_data_ptr); + +AOTITorchError aoti_torch_get_dim( + AtenTensorHandle tensor, int64_t* ret_dim); + +AOTITorchError aoti_torch_get_sizes( + AtenTensorHandle tensor, int64_t** ret_sizes); + +AOTITorchError aoti_torch_get_strides( + AtenTensorHandle tensor, int64_t** ret_strides); + +/* ---- dtype ------------------------------------------------------------- */ + +AOTITorchError aoti_torch_get_dtype( + AtenTensorHandle tensor, int32_t* ret_dtype); + +int32_t aoti_torch_dtype_float16(void); +int32_t aoti_torch_dtype_float32(void); +int32_t aoti_torch_dtype_float64(void); +int32_t aoti_torch_dtype_bfloat16(void); +int32_t aoti_torch_dtype_uint8(void); +int32_t aoti_torch_dtype_int8(void); +int32_t aoti_torch_dtype_int16(void); +int32_t aoti_torch_dtype_int32(void); +int32_t aoti_torch_dtype_int64(void); +int32_t aoti_torch_dtype_bool(void); +int32_t aoti_torch_dtype_complex32(void); +int32_t aoti_torch_dtype_complex64(void); +int32_t aoti_torch_dtype_complex128(void); + +/* ---- device ------------------------------------------------------------ */ + +AOTITorchError aoti_torch_get_device_type( + AtenTensorHandle tensor, int32_t* ret_device_type); + +AOTITorchError aoti_torch_get_device_index( + AtenTensorHandle tensor, int32_t* ret_device_index); + +int32_t aoti_torch_device_type_cpu(void); +int32_t aoti_torch_device_type_cuda(void); + +/* ---- stream -------------------------------------------------------------- */ + +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, void** ret_stream); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* CUDA_CORE_AOTI_SHIM_H */ diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index e0439ef23c..3ebde8dcff 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t from cuda.core._layout cimport _StridedLayout, get_strides_ptr from cuda.core._stream import Stream +import ctypes import functools +import sys import warnings import numpy @@ -29,6 +31,73 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._memory import Buffer +# --------------------------------------------------------------------------- +# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used) +# --------------------------------------------------------------------------- + +cdef object _tensor_bridge = None +# Cache: type(obj) -> True/False for the torch tensor check. +# Once a type is seen, we never re-check. +cdef dict _torch_type_cache = {} +# Tri-state: None = not checked, True/False = result of version check +cdef object _torch_version_ok = None + +cdef inline bint _torch_version_check(): + """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized. + + Lower bound: AOTI functions we use were introduced in PyTorch 2.3. + Upper bound: the ``pyobj_to_aten_handle`` trick relies on the + THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata) + and the identity ``AtenTensorHandle == at::Tensor*``. Both are + undocumented internals that could change in a future PyTorch version. + We cap at the latest version we have tested against; unknown versions + fall back to the standard DLPack/CAI paths. Bump the upper bound + after verifying a new PyTorch release. + """ + global _torch_version_ok + if _torch_version_ok is not None: + return _torch_version_ok + torch = sys.modules.get("torch") + if torch is None: + _torch_version_ok = False + return False + try: + major, minor = int(torch.__version__.split(".")[0]), \ + int(torch.__version__.split(".")[1]) + _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11) + except (ValueError, IndexError): + _torch_version_ok = False + return _torch_version_ok + + +cdef inline bint _is_torch_tensor(object obj): + cdef type tp = type(obj) + cdef object cached = _torch_type_cache.get(tp) + if cached is not None: + return cached + cdef str mod = tp.__module__ or "" + cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \ + and _torch_version_check() + _torch_type_cache[tp] = result + return result + + +cdef object _get_tensor_bridge(): + """Bootstrap AOTI symbols, then import _tensor_bridge on first use.""" + global _tensor_bridge + if _tensor_bridge is not None: + return _tensor_bridge + torch_C = sys.modules.get("torch._C") + if torch_C is None: + raise RuntimeError( + "torch._C is not loaded; cannot initialise the tensor bridge. " + "Make sure PyTorch is imported before passing a torch.Tensor.") + ctypes.CDLL(torch_C.__file__, mode=ctypes.RTLD_GLOBAL) + from cuda.core import _tensor_bridge as tb + _tensor_bridge = tb + return _tensor_bridge + + try: from ml_dtypes import bfloat16 except ImportError: @@ -150,6 +219,9 @@ cdef class StridedMemoryView: Stream pointer for synchronization. If ``None``, no synchronization is performed. """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf view_as_dlpack(obj, stream_ptr, buf) return buf @@ -165,6 +237,9 @@ cdef class StridedMemoryView: Stream pointer for synchronization. If ``None``, no synchronization is performed. """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf view_as_cai(obj, stream_ptr, buf) return buf @@ -178,6 +253,9 @@ cdef class StridedMemoryView: An object implementing the `__array_interface__ `_ protocol (e.g., a numpy array). """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, None, buf) + return buf view_as_array_interface(obj, buf) return buf @@ -187,6 +265,8 @@ cdef class StridedMemoryView: Tries `DLPack `_ first, then falls back to `__cuda_array_interface__ `_. + ``torch.Tensor`` objects are transparently handled via a fast AOTI path + regardless of which protocol is selected. Parameters ---------- @@ -480,6 +560,10 @@ cdef class StridedMemoryView: if self._dtype is None: if self.dl_tensor != NULL: self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) + elif isinstance(self.metadata, int): + # AOTI dtype code stored by the torch tensor bridge + self._dtype = _get_tensor_bridge().resolve_aoti_dtype( + self.metadata) elif self.metadata is not None: self._dtype = _typestr2dtype(self.metadata["typestr"]) return self._dtype @@ -1122,6 +1206,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): as_cu(h_event), producer_s)) HANDLE_RETURN(cydriver.cuStreamWaitEvent( consumer_s, as_cu(h_event), 0)) + elif _is_torch_tensor(obj): + # PyTorch's __cuda_array_interface__ reports version 2 and + # omits the "stream" field, so the standard CAI sync path + # above is a no-op for torch tensors. This is unsafe: the + # consumer has no guarantee that the producer's work is + # visible. We fix this by querying PyTorch's current CUDA + # stream via the AOTI stable C ABI and performing the same + # event-based stream ordering. + _get_tensor_bridge().sync_torch_stream( + buf.device_id, (stream_ptr)) return buf diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx new file mode 100644 index 0000000000..93c4aa47a8 --- /dev/null +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -0,0 +1,327 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tensor bridge: extract PyTorch tensor metadata via the AOTI stable C ABI. + +PyTorch is NOT required at build time. At runtime the AOTI symbols are +resolved from ``torch._C`` (which is loaded with ``RTLD_GLOBAL``). + +The ``pyobj_to_aten_handle`` trick exploits the internal layout of +``THPVariable`` (PyTorch's Python tensor wrapper):: + + struct THPVariable { + PyObject_HEAD + at::Tensor cdata; // <-- &cdata is usable as AtenTensorHandle + ... + }; + +In PyTorch 2.3–2.9 ``cdata`` was ``c10::MaybeOwned``; +from 2.10 onward it is ``at::Tensor``. In both cases ``&cdata`` +(offset ``sizeof(PyObject)`` from the start of the object) is accepted +by the AOTI stable C ABI functions as an ``AtenTensorHandle``. + +Offsetting past ``PyObject_HEAD`` gives us the handle +without any Python attribute access or method calls (~14 ns for all +7 metadata queries). + +Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. + +.. note:: + + This module must NOT be imported at ``cuda.core`` load time. It is + loaded lazily (by ``_memoryview.pyx``) only when the user actually + passes a ``torch.Tensor``. The caller must ensure that + ``torch._C`` has been re-opened with ``RTLD_GLOBAL`` *before* + importing this module so that the AOTI symbols are visible. +""" + +from libc.stdint cimport intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t + +from cuda.core._memoryview cimport StridedMemoryView +from cuda.core._layout cimport _StridedLayout +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport ( + EventHandle, + create_event_handle_noctx, + as_cu, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN + +cdef extern from "Python.h": + ctypedef struct PyObject: + pass + +cdef extern from "_include/aoti_shim.h": + ctypedef int32_t AOTITorchError + + ctypedef struct AtenTensorOpaque: + pass + ctypedef AtenTensorOpaque* AtenTensorHandle + + # tensor metadata + AOTITorchError aoti_torch_get_data_ptr(AtenTensorHandle, void**) + AOTITorchError aoti_torch_get_dim(AtenTensorHandle, int64_t*) + AOTITorchError aoti_torch_get_sizes(AtenTensorHandle, int64_t**) + AOTITorchError aoti_torch_get_strides(AtenTensorHandle, int64_t**) + + # dtype + AOTITorchError aoti_torch_get_dtype(AtenTensorHandle, int32_t*) + int32_t aoti_torch_dtype_float16() + int32_t aoti_torch_dtype_float32() + int32_t aoti_torch_dtype_float64() + int32_t aoti_torch_dtype_bfloat16() + int32_t aoti_torch_dtype_uint8() + int32_t aoti_torch_dtype_int8() + int32_t aoti_torch_dtype_int16() + int32_t aoti_torch_dtype_int32() + int32_t aoti_torch_dtype_int64() + int32_t aoti_torch_dtype_bool() + int32_t aoti_torch_dtype_complex32() + int32_t aoti_torch_dtype_complex64() + int32_t aoti_torch_dtype_complex128() + + # device + AOTITorchError aoti_torch_get_device_type(AtenTensorHandle, int32_t*) + AOTITorchError aoti_torch_get_device_index(AtenTensorHandle, int32_t*) + int32_t aoti_torch_device_type_cpu() + int32_t aoti_torch_device_type_cuda() + + # stream + AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**) + +import numpy + + +# --------------------------------------------------------------------------- +# Module-level state (initialised at import time — AOTI symbols are +# guaranteed visible because _memoryview bootstraps RTLD_GLOBAL before +# importing us) +# --------------------------------------------------------------------------- + +cdef int32_t _DEVICE_TYPE_CPU = aoti_torch_device_type_cpu() +cdef int32_t _DEVICE_TYPE_CUDA = aoti_torch_device_type_cuda() +cdef dict _aoti_dtype_map = None +cdef dict _aoti_itemsize_map = None + + +# --------------------------------------------------------------------------- +# pointer extraction +# --------------------------------------------------------------------------- + +cdef inline AtenTensorHandle pyobj_to_aten_handle(object obj): + """Extract AtenTensorHandle by offsetting past PyObject_HEAD. + + In PyTorch 2.3–2.9 the first field after PyObject_HEAD is + ``c10::MaybeOwned cdata``; from 2.10 onward it is + ``at::Tensor cdata``. In both cases the address of ``cdata`` + is usable as an ``AtenTensorHandle`` (``at::Tensor*``) for the + AOTI stable C ABI functions. + """ + return (obj + sizeof(PyObject)) + + +cdef inline int check_aoti(AOTITorchError err, const char* name) except? -1: + """Raise RuntimeError if an AOTI call returned a non-zero error code.""" + if err != 0: + raise RuntimeError(f"{name.decode()} failed") + return 0 + + +# --------------------------------------------------------------------------- +# dtype mapping (AOTI int32 -> numpy dtype) +# --------------------------------------------------------------------------- + +cdef dict _build_dtype_map(): + try: + from ml_dtypes import bfloat16 as _bf16 # noqa: F811 + has_bfloat16 = True + except ImportError: + has_bfloat16 = False + + cdef dict m = { + aoti_torch_dtype_float16(): numpy.dtype(numpy.float16), + aoti_torch_dtype_float32(): numpy.dtype(numpy.float32), + aoti_torch_dtype_float64(): numpy.dtype(numpy.float64), + aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8), + aoti_torch_dtype_int8(): numpy.dtype(numpy.int8), + aoti_torch_dtype_int16(): numpy.dtype(numpy.int16), + aoti_torch_dtype_int32(): numpy.dtype(numpy.int32), + aoti_torch_dtype_int64(): numpy.dtype(numpy.int64), + aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_), + aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64), + aoti_torch_dtype_complex128(): numpy.dtype(numpy.complex128), + } + if has_bfloat16: + m[aoti_torch_dtype_bfloat16()] = numpy.dtype(_bf16) + return m + + +cdef object _get_aoti_dtype(int32_t dtype_code): + global _aoti_dtype_map + if _aoti_dtype_map is None: + _aoti_dtype_map = _build_dtype_map() + result = _aoti_dtype_map.get(dtype_code) + if result is None: + raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}") + return result + + +def resolve_aoti_dtype(int32_t dtype_code): + """Python-callable wrapper around _get_aoti_dtype (for lazy resolution).""" + return _get_aoti_dtype(dtype_code) + + +cdef dict _build_itemsize_map(): + return { + aoti_torch_dtype_bool(): sizeof(uint8_t), + aoti_torch_dtype_uint8(): sizeof(uint8_t), + aoti_torch_dtype_int8(): sizeof(int8_t), + aoti_torch_dtype_float16(): sizeof(int16_t), # no C float16 + aoti_torch_dtype_bfloat16(): sizeof(int16_t), # no C bfloat16 + aoti_torch_dtype_int16(): sizeof(int16_t), + aoti_torch_dtype_complex32(): 2 * sizeof(int16_t), # no C complex32 + aoti_torch_dtype_float32(): sizeof(float), + aoti_torch_dtype_int32(): sizeof(int32_t), + aoti_torch_dtype_complex64(): 2 * sizeof(float), + aoti_torch_dtype_float64(): sizeof(double), + aoti_torch_dtype_int64(): sizeof(int64_t), + aoti_torch_dtype_complex128(): 2 * sizeof(double), + } + + +cdef int _get_aoti_itemsize(int32_t dtype_code) except -1: + global _aoti_itemsize_map + if _aoti_itemsize_map is None: + _aoti_itemsize_map = _build_itemsize_map() + result = _aoti_itemsize_map.get(dtype_code) + if result is None: + raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}") + return result + + +# --------------------------------------------------------------------------- +# Stream ordering helper +# --------------------------------------------------------------------------- + +cpdef int sync_torch_stream(int32_t device_index, + intptr_t consumer_s) except? -1: + """Establish stream ordering between PyTorch's current CUDA stream + and the given consumer stream. + + Records an event on PyTorch's current stream (the producer) and makes + the consumer stream wait on it. This is a no-op if both streams are + the same. + """ + cdef void* producer_s + cdef EventHandle h_event + + check_aoti(aoti_torch_get_current_cuda_stream(device_index, &producer_s), + b"aoti_torch_get_current_cuda_stream") + if producer_s != consumer_s: + with nogil: + h_event = create_event_handle_noctx( + cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING) + HANDLE_RETURN(cydriver.cuEventRecord( + as_cu(h_event), producer_s)) + HANDLE_RETURN(cydriver.cuStreamWaitEvent( + consumer_s, as_cu(h_event), 0)) + return 0 + + +# --------------------------------------------------------------------------- +# Public API: construct StridedMemoryView from a torch.Tensor +# --------------------------------------------------------------------------- + +def view_as_torch_tensor(object obj, object stream_ptr, view=None): + """Create/populate a :class:`StridedMemoryView` from a ``torch.Tensor``. + + This is a fast path that avoids DLPack/CAI protocol overhead by + reading tensor metadata directly through the AOTI stable C ABI. + + Parameters + ---------- + obj : torch.Tensor + The source tensor. + stream_ptr : int or None + Consumer stream pointer. When not ``-1``, stream ordering is + established between PyTorch's current CUDA stream (the producer) + and the consumer stream, matching the DLPack contract. + view : StridedMemoryView, optional + If provided, populate this existing view in-place. Otherwise a + new instance is created. + """ + cdef AtenTensorHandle handle = pyobj_to_aten_handle(obj) + cdef void* data_ptr + cdef int64_t ndim + cdef int64_t* sizes_ptr + cdef int64_t* strides_ptr + cdef int32_t dtype_code + cdef int32_t device_type, device_index + cdef StridedMemoryView buf + cdef int itemsize + cdef intptr_t _stream_ptr_int + cdef _StridedLayout layout + + check_aoti(aoti_torch_get_data_ptr(handle, &data_ptr), + b"aoti_torch_get_data_ptr") + check_aoti(aoti_torch_get_dim(handle, &ndim), + b"aoti_torch_get_dim") + check_aoti(aoti_torch_get_sizes(handle, &sizes_ptr), + b"aoti_torch_get_sizes") + check_aoti(aoti_torch_get_strides(handle, &strides_ptr), + b"aoti_torch_get_strides") + check_aoti(aoti_torch_get_dtype(handle, &dtype_code), + b"aoti_torch_get_dtype") + check_aoti(aoti_torch_get_device_type(handle, &device_type), + b"aoti_torch_get_device_type") + check_aoti(aoti_torch_get_device_index(handle, &device_index), + b"aoti_torch_get_device_index") + + # -- populate StridedMemoryView -- + if view is not None: + buf = view + else: + buf = StridedMemoryView.__new__(StridedMemoryView) + + buf.ptr = data_ptr + buf.readonly = False + buf.exporting_obj = obj + buf.dl_tensor = NULL + buf.metadata = None + buf._buffer = None + + if device_type == _DEVICE_TYPE_CPU: + buf.device_id = -1 + buf.is_device_accessible = False + elif device_type == _DEVICE_TYPE_CUDA: + buf.device_id = device_index + buf.is_device_accessible = True + + # -- stream ordering (matches the DLPack contract) -- + if stream_ptr is not None: + _stream_ptr_int = int(stream_ptr) + if _stream_ptr_int != -1: + sync_torch_stream(device_index, _stream_ptr_int) + else: + raise BufferError( + f"Unsupported device type from torch tensor " + f"(AOTI device type id: {device_type})") + + # Defer full numpy dtype resolution until first .dtype access. + # Store the raw AOTI dtype code in metadata for lazy lookup. + buf.metadata = dtype_code + + # Build _StridedLayout. init_from_ptr copies shape/strides so we are + # safe even though they are borrowed pointers. + itemsize = _get_aoti_itemsize(dtype_code) + layout = _StridedLayout.__new__(_StridedLayout) + layout.init_from_ptr( + ndim, + sizes_ptr, + strides_ptr, + itemsize, + ) + buf._layout = layout + + return buf diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst new file mode 100644 index 0000000000..34eff57100 --- /dev/null +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -0,0 +1,35 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. currentmodule:: cuda.core + +``cuda.core`` 1.0.0 Release Notes +================================= + + +Highlights +---------- + +- TBD + + +New features +------------ + +- TBD + + +Fixes and enhancements +----------------------- + +- :class:`~utils.StridedMemoryView` now provides a fast path for ``torch.Tensor`` + objects via PyTorch's AOT Inductor (AOTI) stable C ABI. When a ``torch.Tensor`` + is passed to any ``from_*`` classmethod (``from_dlpack``, + ``from_cuda_array_interface``, ``from_array_interface``, or + ``from_any_interface``), tensor metadata is read directly from the underlying + C struct, bypassing the DLPack and CUDA Array Interface protocol overhead. + This yields ~7-20x faster ``StridedMemoryView`` construction for PyTorch + tensors (depending on whether stream ordering is required). Proper CUDA stream ordering is established between PyTorch's current + stream and the consumer stream, matching the DLPack synchronization contract. + Requires PyTorch >= 2.3. + (`#749 `__) diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 59829f8fb3..8874fe1e0a 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -712,3 +712,147 @@ def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, a smv = api(a, stream_ptr=0) with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"): smv.dtype # noqa: B018 + + +# =================================================================== +# Tensor bridge (torch.Tensor fast path via AOTI stable C ABI) +# =================================================================== + +_torch_skip = pytest.mark.skipif(torch is None, reason="PyTorch is not installed") + + +@_torch_skip +@pytest.mark.parametrize( + "dtype", + [ + pytest.param("float16", id="float16"), + pytest.param("float32", id="float32"), + pytest.param("float64", id="float64"), + pytest.param("int8", id="int8"), + pytest.param("int16", id="int16"), + pytest.param("int32", id="int32"), + pytest.param("int64", id="int64"), + pytest.param("uint8", id="uint8"), + pytest.param("bool", id="bool"), + pytest.param("complex64", id="complex64"), + pytest.param("complex128", id="complex128"), + ], +) +def test_torch_tensor_bridge_dtypes(init_cuda, dtype): + """Verify that dtype mapping via the tensor bridge matches torch's own dtype.""" + torch_dtype = getattr(torch, dtype) + a = torch.tensor([1, 0, 1], dtype=torch_dtype, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.dtype.itemsize == a.element_size() + assert smv.ptr == a.data_ptr() + + +@_torch_skip +@pytest.mark.skipif(ml_dtypes is None, reason="ml_dtypes is not installed") +def test_torch_tensor_bridge_bfloat16(init_cuda): + a = torch.tensor([1, 2, 3], dtype=torch.bfloat16, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.dtype == np.dtype("bfloat16") + assert smv.ptr == a.data_ptr() + + +@_torch_skip +def test_torch_tensor_bridge_cuda_1d(init_cuda): + a = torch.arange(12, dtype=torch.float32, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (12,) + assert smv.strides in (None, (1,)) # C-contiguous may be None + assert smv.dtype == np.dtype(np.float32) + assert smv.device_id == init_cuda.device_id + assert smv.is_device_accessible is True + assert smv.readonly is False + assert smv.exporting_obj is a + + +@_torch_skip +def test_torch_tensor_bridge_cuda_nd(init_cuda): + a = torch.arange(24, dtype=torch.float32, device="cuda").reshape(2, 3, 4) + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (2, 3, 4) + assert smv.dtype == np.dtype(np.float32) + assert smv.device_id == init_cuda.device_id + assert smv.is_device_accessible is True + + +@_torch_skip +def test_torch_tensor_bridge_non_contiguous(init_cuda): + """Transposed tensor should have non-trivial strides.""" + a = torch.arange(12, dtype=torch.float32, device="cuda").reshape(3, 4).t() + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.shape == (4, 3) + # torch.stride() returns element counts, same as StridedMemoryView + assert smv.strides == tuple(a.stride()) + assert smv.ptr == a.data_ptr() + + +@_torch_skip +def test_torch_tensor_bridge_sliced(init_cuda): + """Sliced tensor should have correct data_ptr (accounts for storage offset).""" + base = torch.arange(100, dtype=torch.int64, device="cuda") + a = base[10:20] + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (10,) + assert smv.dtype == np.dtype(np.int64) + + +@_torch_skip +def test_torch_tensor_bridge_sliced_2d(init_cuda): + """2D sliced tensor should have correct data_ptr, shape, and strides.""" + base = torch.arange(60, dtype=torch.float32, device="cuda").reshape(6, 10) + a = base[1:4, 2:7] + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (3, 5) + assert smv.strides == (10, 1) # element strides + assert smv.dtype == np.dtype(np.float32) + + +@_torch_skip +def test_torch_tensor_bridge_scalar(init_cuda): + a = torch.tensor(42.0, dtype=torch.float32, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == () + assert smv.dtype == np.dtype(np.float32) + + +@_torch_skip +def test_torch_tensor_bridge_empty(init_cuda): + a = torch.empty(0, dtype=torch.float32, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.shape == (0,) + assert smv.dtype == np.dtype(np.float32) + + +@_torch_skip +def test_torch_tensor_bridge_cpu(init_cuda): + a = torch.arange(5, dtype=torch.float32, device="cpu") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=-1) + assert smv.ptr == a.data_ptr() + assert smv.shape == (5,) + assert smv.device_id == -1 + assert smv.is_device_accessible is False + + +@_torch_skip +def test_torch_tensor_bridge_decorator(init_cuda): + """Verify tensor bridge works through the args_viewable_as_strided_memory decorator.""" + + @args_viewable_as_strided_memory((0,)) + def fn(tensor, stream): + return tensor.view(stream.handle) + + a = torch.arange(6, dtype=torch.float32, device="cuda").reshape(2, 3) + stream = Device().create_stream() + smv = fn(a, stream) + assert smv.ptr == a.data_ptr() + assert smv.shape == (2, 3) + assert smv.dtype == np.dtype(np.float32)