From 231b635b1e38fc75851212fdf8e9a6f588f7c882 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 9 Jun 2026 21:49:54 +0000 Subject: [PATCH 1/2] Add FlexQuantization Signed-off-by: Kaining Zhong --- build_tools/pytorch.py | 40 +- pyproject.toml | 2 +- .../cutedsl/test_flex_mxfp8_quantization.py | 111 ++ .../common/cutedsl/cast/mxfp8_quantization.py | 747 +++++++++++ .../common/cutedsl/cast/quantization_utils.py | 1125 +++++++++++++++++ .../common/cutedsl/cutedsl_utils.py | 34 + .../transformer_engine/transformer_engine.h | 3 + transformer_engine/pytorch/csrc/common.h | 38 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 25 + .../pytorch/csrc/extensions/pybind.cpp | 23 +- transformer_engine/pytorch/csrc/pybind.h | 15 +- transformer_engine/pytorch/csrc/quantizer.cpp | 586 ++++++++- .../pytorch/csrc/tvm_ffi_bridge.h | 171 +++ .../pytorch/csrc/type_converters.cpp | 46 + .../pytorch/tensor/flex_tensor.py | 864 +++++++++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 28 +- .../pytorch/tensor/mxfp8_tensor.py | 28 +- .../pytorch/tensor/nvfp4_tensor.py | 27 +- .../tensor/storage/flex_tensor_storage.py | 393 ++++++ .../tensor/storage/mxfp8_tensor_storage.py | 15 +- .../tensor/storage/nvfp4_tensor_storage.py | 13 +- transformer_engine/pytorch/utils.py | 16 +- 23 files changed, 4245 insertions(+), 107 deletions(-) create mode 100644 tests/pytorch/cutedsl/test_flex_mxfp8_quantization.py create mode 100644 transformer_engine/common/cutedsl/cast/mxfp8_quantization.py create mode 100644 transformer_engine/common/cutedsl/cast/quantization_utils.py create mode 100644 transformer_engine/common/cutedsl/cutedsl_utils.py create mode 100644 transformer_engine/pytorch/csrc/tvm_ffi_bridge.h create mode 100644 transformer_engine/pytorch/tensor/flex_tensor.py create mode 100644 transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..d8eba20bcc 100644 --- 
a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -4,6 +4,7 @@ """PyTorch related extensions.""" +import importlib.util import os from pathlib import Path from importlib import metadata @@ -22,7 +23,17 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"] + return [ + "torch>=2.1", + "einops", + "onnxscript", + "onnx", + "packaging", + "pydantic", + "nvdlfw-inspect", + "apache-tvm-ffi", + "nvidia-cutlass-dsl>=4.5.0", + ] def test_requirements() -> List[str]: @@ -58,6 +69,25 @@ def setup_pytorch_extension( ] ) + # apache-tvm-ffi: headers for the C++ API (Module / Function / TensorView) + # and libtvm_ffi.so for symbol resolution. Used by tvm_ffi_bridge.h / + # applyTVMFunction. Python registers AOT-compiled CuTeDSL kernels into + # the global registry; TE C++ looks them up via Function::GetGlobalRequired. + tvm_ffi_spec = importlib.util.find_spec("tvm_ffi") + if tvm_ffi_spec is None or not tvm_ffi_spec.submodule_search_locations: + raise RuntimeError( + "apache-tvm-ffi package not found; install it (e.g. " + "`pip install apache-tvm-ffi`) — required for the TVM FFI bridge." 
+ ) + tvm_ffi_root = Path(tvm_ffi_spec.submodule_search_locations[0]) + tvm_ffi_include = tvm_ffi_root / "include" + tvm_ffi_lib_dir = tvm_ffi_root / "lib" + if not tvm_ffi_include.is_dir() or not (tvm_ffi_lib_dir / "libtvm_ffi.so").exists(): + raise RuntimeError( + f"apache-tvm-ffi assets missing at {tvm_ffi_root} (need include/ and lib/libtvm_ffi.so)" + ) + include_dirs.append(tvm_ffi_include) + # Compiler flags cxx_flags = ["-O3", "-fvisibility=hidden"] if debug_build_enabled(): @@ -77,8 +107,11 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) - library_dirs = [] - libraries = [] + library_dirs = [tvm_ffi_lib_dir] + libraries = ["tvm_ffi"] + # rpath pinned to the pip install dir so the loader finds libtvm_ffi.so + # without LD_LIBRARY_PATH at runtime. + extra_link_args = [f"-Wl,-rpath,{tvm_ffi_lib_dir}"] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): assert ( os.getenv("NVSHMEM_HOME") is not None @@ -102,6 +135,7 @@ def setup_pytorch_extension( sources=[str(src) for src in sources], include_dirs=[str(inc) for inc in include_dirs], extra_compile_args={"cxx": cxx_flags}, + extra_link_args=extra_link_args, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/pyproject.toml b/pyproject.toml index 4a8fded172..9df6aa4ac9 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ # See LICENSE for license information. 
[build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1", "apache-tvm-ffi"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/tests/pytorch/cutedsl/test_flex_mxfp8_quantization.py b/tests/pytorch/cutedsl/test_flex_mxfp8_quantization.py new file mode 100644 index 0000000000..39558d3754 --- /dev/null +++ b/tests/pytorch/cutedsl/test_flex_mxfp8_quantization.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +from transformer_engine.common.cutedsl.cutedsl_utils import str_to_te_dtype +import transformer_engine.pytorch # noqa: F401 (loads libtransformer_engine.so) +import transformer_engine_torch as tex +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.common.cutedsl.cast.mxfp8_quantization import ( + get_mxfp8_quantizer, +) + +MXFP8_BLOCK = 32 # MXFP8 scale block size; valid shapes must be multiples of this. 
+ +# 2 aligned (no scale padding) + 2 padded (partial tiles); +SHAPES = [(256, 256), (128, 512), (96, 224), (160, 96)] + +def get_dtype_combinations(): + dtype_row = ("e4m3", "e5m2", "none") + dtype_column = ("e4m3", "e5m2", "none") + return [(r, c) for r in dtype_row for c in dtype_column] + +DTYPE_PAIRS = get_dtype_combinations() + +def reference_quantize(x, fp8_type, rowwise, columnwise, swizzle): + q = MXFP8Quantizer(fp8_dtype=str_to_te_dtype(fp8_type), rowwise=rowwise, columnwise=columnwise) + q.optimize_for_gemm = swizzle # makes the native kernel emit swizzled scales + ref = tex.quantize(x.clone(), q) + return ref + +@pytest.mark.parametrize("swizzle", [False, True]) +@pytest.mark.parametrize("dtype_pair", DTYPE_PAIRS) +@pytest.mark.parametrize("shape", SHAPES) +def test_flex_mxfp8_bitexact(shape, dtype_pair, swizzle): + M, N = shape + dtype_row, dtype_column = dtype_pair + torch.manual_seed(0) + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + + # No direction is invalid -- the quantizer must reject it at construction. + if dtype_row == "none" and dtype_column == "none": + with pytest.raises(ValueError): + get_mxfp8_quantizer(x, dtype_row, dtype_column, with_gemm_swizzled_scales=swizzle) + return + + flex_q = get_mxfp8_quantizer( + x, dtype_row=dtype_row, dtype_col=dtype_column, with_gemm_swizzled_scales=swizzle + ) + flex = tex.quantize(x, flex_q) + torch.cuda.synchronize() + + if dtype_row != "none": + scale_M, scale_N = M, N // MXFP8_BLOCK + # Reference for this direction uses THIS direction's dtype. 
+ ref = reference_quantize(x, dtype_row, rowwise=True, columnwise=False, swizzle=swizzle) + assert ref._rowwise_data.shape == flex._rowwise_data.shape, "rowwise data shape mismatch" + assert ref._rowwise_scale_inv.shape == flex._rowwise_scale_inv.shape, "rowwise scale shape mismatch" + torch.testing.assert_close(flex._rowwise_data, ref._rowwise_data, rtol=0, atol=0) # bit-identical + if swizzle: + torch.testing.assert_close(flex._rowwise_scale_inv, ref._rowwise_scale_inv, rtol=0, atol=0) + else: + torch.testing.assert_close( + flex._rowwise_scale_inv[:scale_M, :scale_N], + ref._rowwise_scale_inv[:scale_M, :scale_N], + rtol=0, atol=0 + ) + else: + assert flex._rowwise_data is None, "row=none must not produce rowwise data" + + if dtype_column != "none": + scale_M, scale_N = M // MXFP8_BLOCK, N + ref = reference_quantize(x, dtype_column, rowwise=False, columnwise=True, swizzle=swizzle) + assert ref._columnwise_data.shape == flex._columnwise_data.shape, "columnwise data shape mismatch" + assert ref._columnwise_scale_inv.shape == flex._columnwise_scale_inv.shape, "columnwise scale shape mismatch" + torch.testing.assert_close(flex._columnwise_data, ref._columnwise_data, rtol=0, atol=0) # bit-identical + if swizzle: + torch.testing.assert_close(flex._columnwise_scale_inv, ref._columnwise_scale_inv, rtol=0, atol=0) + else: + torch.testing.assert_close( + flex._columnwise_scale_inv[:scale_M, :scale_N], + ref._columnwise_scale_inv[:scale_M, :scale_N], + rtol=0, atol=0 + ) + else: + assert flex._columnwise_data is None, "col=none must not produce colwise data" + +def test_flex_mxfp8_wrong_shape(): + """A quantizer is compiled for a specific (M, N); using it on a different N + must error rather than silently mis-quantize. + + The kernel name is the cache key encoding the baked (constexpr) shape, so + the registered kernel only accepts tensors of that shape -- feeding it a + different N trips the compiled entry's shape guarantee. 
+ """ + M, N = (128, 256) + x1 = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + flex_q = get_mxfp8_quantizer(x1, dtype_row="e4m3", dtype_col="e5m2") + + tex.quantize(x1, flex_q) # sanity: works for the shape it was compiled for + + # Changed N: the AOT entry was compiled with literal shapes, so its baked + # per-arg shape check rejects the mismatched tensor before the kernel runs + # (e.g. "Mismatched mX.shape[1] ..."), rather than silently mis-quantizing. + # `match` keeps the test from passing on some unrelated failure. + x2 = torch.randn(M, N * 2, dtype=torch.bfloat16, device="cuda") + with pytest.raises(RuntimeError, match="[Mm]ismatch"): + tex.quantize(x2, flex_q) + torch.cuda.synchronize() diff --git a/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py b/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py new file mode 100644 index 0000000000..350c8c61bc --- /dev/null +++ b/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py @@ -0,0 +1,747 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MXFP8 quantization kernel implemented in CuTeDSL. + +Replicates the core logic of quantize_mxfp8.cuh: given a 2D tensor of BF16/FP16 +values, quantize to MXFP8 format (FP8E4M3 data + E8M0 per-block scales). + +Matches the C++ kernel's tile dimensions and thread layout: + CHUNK_DIM_Y = 64, CHUNK_DIM_X = 64, THREADS_PER_CHUNK = 64 + BUFF_DIM_Y = 32, BUFF_DIM_X = 64, STAGES = 2 + SCALE_DIM = 32 (elements per MXFP8 scaling block) + +Grid: (ceil(N / 64), ceil(M / 64)) +Each block processes a 64x64 chunk in 2 stages of 32x64 tiles loaded into +shared memory. 
+""" + +from transformer_engine.common.cutedsl.cutedsl_utils import str_to_te_dtype, torch_to_cutlass_dtype +from transformer_engine.pytorch.tensor.flex_tensor import FlexQuantizer + +from typing import Literal, Optional, Type, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass import Float32, Int32, Uint8 + +import hashlib +import tvm_ffi as _tvm_ffi + +from transformer_engine.common.cutedsl.cast.quantization_utils import ( + FP8E4M3_MAX_NORM_RCP, + FP8E5M2_MAX_NORM_RCP, + _bitcast_f32_to_i32, + quantize_colwise_mxfp8, + quantize_rowwise_mxfp8, +) + +# MXFP8 settings +MXFP8_BLOCK_SIZE = 32 # Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor +SCALE_DIM = MXFP8_BLOCK_SIZE + +# Double-buffering for async copy + compute overlap +BUFFER_NUM = 2 + +# Vectorised access constants for bank-conflict avoidance (rowwise pass) +PACK_SIZE = 4 # Elements per vector load +WAVES = SCALE_DIM // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +THREADS_PER_WARP = 32 +TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) +THREADS_PER_BANK = TOTAL_BANKS_WIDTH // SCALE_DIM # 4 threads per bank + +# Tiling sizes +NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) +NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension +TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total +TILE_X = 64 # Each tile has 64 columns + +# CTA size +THREADS_PER_CHUNK = 64 +NUM_WARPS = THREADS_PER_CHUNK // 32 + +class MXFP8QuantizeConfig: + + def __init__(self, + dtype: torch.dtype, + dtype_row: Union[Literal["e4m3", "e5m2", "none"]], + dtype_column: Union[Literal["e4m3", "e5m2", "none"]], + with_gemm_swizzled_scales=False): + self.DTYPE = dtype + 
self.DTYPE_ROW = dtype_row + self.ROWWISE = dtype_row != "none" + self.COLUMNWISE = dtype_column != "none" + self.DTYPE_COLUMN = dtype_column + self.WITH_GEMM_SWIZZLED_SCALES = with_gemm_swizzled_scales + # No amax / no activation for the MXFP8 path (kept as + # const-expr-false flags so the kernel's amax/activation branches are + # dead-stripped). NVFP4 (which needs amax) will flip these later. + self.WITH_AMAX = False + self.ACTIVATION = None + +class MXFP8QuantizeKernel: + """MXFP8 quantization with shared-memory tiling (rowwise, colwise, or both). + + Matches C++ kernel's BIDIMENSIONAL scaling mode: + Grid (ceil(N/64), ceil(M/64)) + Block (64) + Each block processes a 64x64 chunk in 2 stages of 32x64. + + Per stage, the tile is loaded into shared memory once. The colwise + pass reads columns from smem first, then the rowwise pass reads rows. + When both directions are enabled, global memory is read only once per + element — matching the C++ single-pass behaviour. + + Thread mappings (per stage): + Colwise: thread tidx handles column tidx, 32 rows (stride BUFF_DIM_X). + Rowwise: tid_Y = tidx // 2 -> row, tid_X = tidx % 2 -> scale-block. + """ + + def __init__(self, cfg): + self.cfg = cfg + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # Input tensor to quantize + mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], # Rowwise data + scale + mA_row: Optional[cute.Tensor], # Rowwise amax (None for MXFP8) + mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Colwise data + scale + mA_col: Optional[cute.Tensor], # Colwise amax (None for MXFP8) + rng_state: Optional[cute.Tensor], # SR seed/offset (None when SR disabled) + stream: cuda.CUstream, # launch stream (C++ passes the handle as an int64 scalar) + ): + M = mX.shape[0] + N = mX.shape[1] + cfg = self.cfg + num_scale_cols = N // SCALE_DIM + num_scale_rows = M // SCALE_DIM + + # Rewrap mS_row / mS_col with the GEMM-swizzled layout when requested. 
+ # Wrapper passes in a tensor with the compact (M, N/32):(N/32, 1) layout + # (built from a compact fake-ptr at compile time), and we re-view the + # underlying buffer here so the per-block scale stores below land at the + # cuBLAS-swizzled byte offsets. + # See https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout + # and swizzle_demo.svg for a visual of the byte permutation. + if cutlass.const_expr(cfg.WITH_GEMM_SWIZZLED_SCALES): + num_tiles_M = (M + 127) // 128 + num_tiles_SC = (num_scale_cols + 3) // 4 # = ceil(N / 128) + num_tiles_SR = (num_scale_rows + 3) // 4 # = ceil(M / 128) + num_tiles_N = (N + 127) // 128 + # row i = i_lo + 32 * (i_hi + 4 * tile_Y); col j = j_lo + 4 * tile_X. + # Within one 128×4 tile: byte = i_lo*16 + i_hi*4 + j_lo. + + # Tile-major outer dims add (tile_Y * num_tiles_SC + tile_X) * 512. + # For example, if M=256, N=512, then num_scale_cols = 16, num_scale_rows = 8, and num_tiles_M=2, num_tiles_SC=4, num_tiles_SR=2, num_tiles_N=4 + # The swizzled layout is ((32, 4, 2), (4, 4)):((16, 4, 2048), (1, 512)) + if cutlass.const_expr(cfg.ROWWISE): + mS_row = cute.make_tensor( + mS_row.iterator, + cute.make_layout( + ((32, 4, num_tiles_M), (4, num_tiles_SC)), + stride=((16, 4, num_tiles_SC * 512), (1, 512)), + ), + ) + # Colwise: same swizzle, axes swap roles — col axis gets the 32×4 + # inner decomp, scale-row axis gets the 4-extent dim. + if cutlass.const_expr(cfg.COLUMNWISE): + mS_col = cute.make_tensor( + mS_col.iterator, + cute.make_layout( + ((4, num_tiles_SR), (32, 4, num_tiles_N)), + stride=((1, 512), (16, 4, num_tiles_SR * 512)), + ), + ) + + # Divide by the STAGE tile (TILE_Y, TILE_X // SCALE_DIM), not the CTA + # tile. Each CTA owns NUM_TILES consecutive row-tiles; the kernel walks + # them by indexing GRID's row dim with `bidy * NUM_TILES + stage` (cute + # auto-decomposes a flat coord onto GRID's hierarchical row modes). 
+ # + # Critically, this is the only divide that cleanly cuts both layouts: + # - compact `(M, N/32):(N/32, 1)` → SCALE_TILE = (32, 2):(N/32, 1) + # - swizzled `((32,4,n_M),(4,n_SC)):((16,4,n_SC·512),(1,512))` + # → SCALE_TILE = (32, 2):(16, 1) + # The bigger (TILE_Y * NUM_TILES, ...) divide we used before tangles the + # swizzle's (32, 4) row hierarchy under flatten + sub-divide chain. + + # Declare TMA descriptors on the host side. + # make_tiled_tma_atom returns the UNTILED gmem tensor with basis strides. + # Tile it inside the kernel with zipped_divide so each coord selects + # one (TILE_Y, TILE_X) tile. + smem_tile_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + cta_tiler = (TILE_Y, TILE_X) + + # Input: TMA G2S (bf16/fp16 → smem). + op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + tma_atom, tma_src = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_load, mX, smem_tile_layout, cta_tiler, num_multicast=1, + ) + + # Output: TMA S2G (uint8 smem → gmem) for both directions. Creating + # both atoms unconditionally — if a direction is disabled the kernel + # simply won't dispatch its copy, and the atom cost is negligible. + op_store = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + out_smem_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + tma_atom_out_row = None + tma_dst_out_row = None + tma_atom_out_col = None + tma_dst_out_col = None + if cutlass.const_expr(cfg.ROWWISE): + tma_atom_out_row, tma_dst_out_row = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_store, mO_row, out_smem_layout, cta_tiler, num_multicast=1, + ) + if cutlass.const_expr(cfg.COLUMNWISE): + tma_atom_out_col, tma_dst_out_col = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_store, mO_col, out_smem_layout, cta_tiler, num_multicast=1, + ) + + # CUDA launches in (0,0), (1,0), (2,0)... 
order, so we should make N the leading dimension for better access pattern + # So consecutive blocks will move along the N dimension first, which is the innermost dimension in memory and we can use cache better + grid = [ + cute.ceil_div(Int32(N), TILE_X), + cute.ceil_div(M, TILE_Y * NUM_TILES), + ] + block = [THREADS_PER_CHUNK,] + + self.kernel( + mX, mS_row, mS_col, None, # mAmax = None (no amax for the MXFP8 path) + mX.element_type, + tma_atom, tma_src, + tma_atom_out_row, tma_dst_out_row, + tma_atom_out_col, tma_dst_out_col, + ).launch( + grid=grid, + block=block, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX, + mS_row, + mS_col, + mAmax, + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + tma_atom, tma_src, # how to use TMA to copy the input + tma_atom_out_row, tma_dst_out_row, # how to use TMA to copy the rowwise output + tma_atom_out_col, tma_dst_out_col, # how to use TMA to copy the colwise output + ): + cfg = self.cfg + + if cutlass.const_expr(cfg.ROWWISE): + mS_row = cute.zipped_divide(mS_row, (TILE_Y, TILE_X // SCALE_DIM)) + if cutlass.const_expr(cfg.COLUMNWISE): + mS_col = cute.zipped_divide(mS_col, (TILE_Y // SCALE_DIM, TILE_X)) + # For M=256, N=512: + # Non-swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28256%2C+16%29%3A%2816%2C+1%29-32%0A2 + # Swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28%2832%2C+4%2C+2%29%2C+%284%2C+4%29%29%3A%28%2816%2C+4%2C+2048%29%2C+%281%2C+512%29%29-32%0A2 + # print(f"mS_row after zipped_divide: {mS_row}") + + # FP8 output smem, one 32×64 tile per stage per enabled direction. + # Allocating a dead sO_col in rowwise-only (or sO_row in colwise-only) + # bumps per-CTA smem from 12 KB to 16 KB, which drops occupancy and + # regresses the single-direction path by ~8-10% at 16384^2. Match + # C++ and only allocate what the active pass actually uses. 
+ # sAmax holds one f32 per warp for the cross-warp amax reduction — + # negligible (8 bytes for NUM_WARPS=2) and we always allocate so the + # struct doesn't fork on a 4th const-expr (cfg.WITH_AMAX) dimension. + if cutlass.const_expr(cfg.ROWWISE and cfg.COLUMNWISE): + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_col: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE and not cfg.COLUMNWISE): + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE): + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + else: + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_col: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Per-stage shmem tile is 2D (TILE_Y, TILE_X); stages laid 
out back-to-back. + # Mode 0 is hierarchical ((TILE_Y, TILE_X),) so it matches the rank/shape + # of gX_tiled[(None, (ty, tx))] produced by zipped_divide. + # sX[(None, stage)] selects one (TILE_Y, TILE_X) tile. + sX = storage.sX.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + if cutlass.const_expr(cfg.ROWWISE): + sO_row = storage.sO_row.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + if cutlass.const_expr(cfg.COLUMNWISE): + sO_col = storage.sO_col.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Prefetch TMA descriptor (one-time; warp-0 only). + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom) + + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + + # Producer: `arrive_and_expect_tx` is wrapped in `elect_one`, so only + # one lane of warp 0 arrives on the full barrier per stage → arrive_count=1. + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + # Consumer: `consumer_release` arrives only on the `is_signalling_thread` + # (lane 0 of each warp), so arrive_count = num_warps per stage. + num_warps = THREADS_PER_CHUNK // 32 + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_warps) + + # Bytes transferred per TMA copy: one (TILE_Y, TILE_X) tile of dtype. 
+ tx_count = TILE_Y * TILE_X * dtype.width // 8 + + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_storage.data_ptr(), + num_stages=NUM_STAGES, + producer_group=producer_group, + consumer_group=consumer_group, + tx_count=tx_count, + cta_layout_vmnk=None, # single-CTA, no cluster/multicast + ) + + prod_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, NUM_STAGES + ) + cons_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, NUM_STAGES + ) + + M = mX.shape[0] + N = mX.shape[1] + + num_tiles = cutlass.min( + NUM_TILES, + cute.ceil_div(M - bidy * TILE_Y * NUM_TILES, TILE_Y), + ) + + # Tile the TMA gmem view: ((TILE_Y, TILE_X), (M/TILE_Y, N/TILE_X)). + gX_tiled = cute.zipped_divide(tma_src, (TILE_Y, TILE_X)) + + # Partition sX/gX for the TMA atom (single-CTA, no cluster/multicast). + tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( + tma_atom, + 0, # Use the only CTA to do the TMA copy + cute.make_layout(1), # This cluster only has 1 CTAs + sX, + gX_tiled, + ) + + # Same partitioning for S2G outputs: sO_row → mO_row and sO_col → mO_col. + if cutlass.const_expr(cfg.ROWWISE): + gO_row_tiled = cute.zipped_divide(tma_dst_out_row, (TILE_Y, TILE_X)) + tXsO_row, tXgO_row = cute.nvgpu.cpasync.tma_partition( + tma_atom_out_row, + 0, + cute.make_layout(1), + sO_row, + gO_row_tiled, + ) + if cutlass.const_expr(cfg.COLUMNWISE): + gO_col_tiled = cute.zipped_divide(tma_dst_out_col, (TILE_Y, TILE_X)) + tXsO_col, tXgO_col = cute.nvgpu.cpasync.tma_partition( + tma_atom_out_col, + 0, + cute.make_layout(1), + sO_col, + gO_col_tiled, + ) + + # print(f"sX: {sX}\n") + # print(f"gX_tiled: {gX_tiled}\n") + # print(f"tXsX: {tXsX}\n") + # print(f"tXgX: {tXgX}\n") + + # Ensure barrier init is visible to all threads before the pipeline is used. + cute.arch.sync_threads() + + # ---- Producer: warp 0 issues one TMA copy per tile. 
---- + if warp_idx == 0: + for stage in cutlass.range(num_tiles, unroll=1): + mainloop_pipeline.producer_acquire(prod_state) + tile_y = bidy * NUM_TILES + stage + cute.copy( + tma_atom, + tXgX[(None, (tile_y, bidx))], + tXsX[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + mainloop_pipeline.producer_commit(prod_state) + prod_state.advance() + + # Per-thread amax accumulator across all stages of this CTA. Combined + # with the per-warp redux + cross-warp shmem reduce + atomic at the + # bottom to produce a global max(|x|) in mAmax. Initialised to 0 + # since amax is non-negative. + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = Float32(0.0) + + # ---- Consumer: all threads quantize each completed tile. ---- + for stage in cutlass.range(num_tiles, unroll=1): + mainloop_pipeline.consumer_wait(cons_state) + sX_tile = sX[(None, stage)] # (TILE_Y, TILE_X) bf16 + + """ + grid = [ + cute.ceil_div(Int32(N), TILE_X), + cute.ceil_div(M, TILE_Y * NUM_TILES), + ] + So to obtain the tile that belongs to this CTA. + """ + # This is just block's x axis idx + tile_idx_x = bidx + # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. + # So the tile index along Y dimension is `bidy * NUM_TILES + stage` + tile_idx_y = bidy * NUM_TILES + stage + + if cutlass.const_expr(cfg.COLUMNWISE): + # The first row that belongs to this CTA. Each CTA handles NUM_TILES of (TILE_Y, TILE_X) tiles stacked vertically, + # and each stage handles one of them. + sO_col_tile = sO_col[(None, stage)] + mS_col_stage = cute.flatten(mS_col[(None, (tile_idx_y, tile_idx_x))]) + + amax_c = self._process_colwise( + sX_tile, sO_col_tile, + mS_col_stage, + tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + ) + + if cutlass.const_expr(cfg.ROWWISE): + sO_row_tile = sO_row[(None, stage)] + # mS_row is ((SCALE_TILE), (GRID)) where SCALE_TILE = (32, 2). + # Each CTA owns NUM_TILES consecutive row-tiles of GRID. 
cute + # auto-decomposes the flat row coord `bidy * NUM_TILES + stage` + # onto GRID's hierarchical row modes — which is the + # (i_hi, tile_Y) tile-major order for swizzled, and the plain + # row-tile order for compact. Same source, both layouts correct. + mS_row_stage = cute.flatten(mS_row[(None, (tile_idx_y, tile_idx_x))]) + # print(f"s0_row_tile: {sO_row_tile}\n") + # print(f"sO_row: {sO_row}\n") + # print(f"mS_row: {mS_row}\n") + # print(f"mS_row_stage: {mS_row_stage}\n") + # print(f"mS_row_stage: {mS_row_stage}\n") + amax_r = self._process_rowwise( + sX_tile, sO_row_tile, + mS_row_stage, + tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + ) + + # Make all smem stores (sO_row and/or sO_col) visible to the TMA + # async proxy, then block-sync so warp 0 sees the fences from all + # warps before issuing the bulk store(s). Matches the C++ + # reference's fence_proxy + __syncthreads pattern. + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + cute.arch.sync_threads() + + if warp_idx == 0: + tile_y = bidy * NUM_TILES + stage + if cutlass.const_expr(cfg.ROWWISE): + cute.copy( + tma_atom_out_row, + tXsO_row[(None, stage)], + tXgO_row[(None, (tile_y, bidx))], + ) + if cutlass.const_expr(cfg.COLUMNWISE): + cute.copy( + tma_atom_out_col, + tXsO_col[(None, stage)], + tXgO_col[(None, (tile_y, bidx))], + ) + cute.arch.cp_async_bulk_commit_group() + + mainloop_pipeline.consumer_release(cons_state) + cons_state.advance() + + # Wait for in-flight TMA stores so data is visible to the host + # before the kernel returns. + cute.arch.cp_async_bulk_wait_group(0, read=False) + + # ---- amax block reduction + cross-CTA atomic ---------------------- + # 1) intra-warp: redux.sync.fmax.f32 (sm_80+, single instruction). + # 2) cross-warp: NUM_WARPS shmem floats + sync_threads. + # 3) cross-CTA: int-atomic-max on the f32 bit pattern. 
Since amax is + # always ≥ 0, IEEE-754 bit ordering on positives matches float + # magnitude ordering, so atomic_max on i32 bits gives the right + # result. (atomic_max_float32 also exists but its pointer + # normalisation is broken as of this CuTeDSL build.) + if cutlass.const_expr(cfg.WITH_AMAX): + warp_amax = cute.arch.warp_redux_sync(block_amax, kind="fmax") + sAmax = storage.sAmax.get_tensor(cute.make_layout(NUM_WARPS)) + lane_idx = tidx % 32 + if lane_idx == 0: + sAmax[warp_idx] = warp_amax + cute.arch.sync_threads() + if tidx == 0: + cta_amax = Float32(0.0) + for w in cutlass.range_constexpr(NUM_WARPS): + cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) + amax_i32 = cute.make_tensor( + cute.recast_ptr(mAmax.iterator, dtype=Int32), + cute.make_layout(1), + ) + cute.arch.atomic_max( + amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), + ) + + + @cute.jit + def _process_rowwise( + self, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, N, # Int32 — full input extents, for OOB masking + ): + """Rowwise MXFP8 pass: thread `(tid_Y, tid_X) = (tidx % 32, tidx // 32)` + owns one 32-element scale block (row `tid_Y`, columns `tid_X*32 .. +32`). + + The bank-group swizzle `((w + bank_group) * PACK_SIZE) % SCALE_DIM` + staggers each 4-thread group's starting wave, which otherwise would + collide on smem banks since all lanes in a warp read different rows + at the same column offset. + + Writes quantized bytes into `sO_row_tile` as u32s (one per wave); + caller is responsible for the TMA S2G flush. 
+ """ + cfg = self.cfg + max_norm_rcp = FP8E4M3_MAX_NORM_RCP + if cutlass.const_expr(cfg.DTYPE_ROW == "e5m2"): + max_norm_rcp = FP8E5M2_MAX_NORM_RCP + return quantize_rowwise_mxfp8( + sX_tile, + sO_row_tile, + mS_row_stage, + max_norm_rcp, + tile_row_start, + tile_col_start, + M, + N, + ACTIVATION=None, + DTYPE=cfg.DTYPE, + FP8_DTYPE=cfg.DTYPE_ROW, + TILE_Y=TILE_Y, + SCALE_DIM=SCALE_DIM, + WAVES=WAVES, + THREADS_PER_WARP=THREADS_PER_WARP, + THREADS_PER_BANK=THREADS_PER_BANK, + PACK_SIZE=PACK_SIZE + ) + + @cute.jit + def _process_colwise( + self, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, N, # Int32 — full input extents, for OOB masking + ): + """Colwise MXFP8 pass: thread `tidx` owns column `tidx` of the (32, 64) + smem tile — 32 elements down. Writes quantized bytes into `sO_col_tile` + so the caller can flush with a TMA S2G — matches C++'s + `out_colwise_data_sh` + `cp.async.bulk.tensor.2d.shared_to_global`. 
+ """ + cfg = self.cfg + max_norm_rcp = FP8E4M3_MAX_NORM_RCP + if cutlass.const_expr(cfg.DTYPE_COLUMN == "e5m2"): + max_norm_rcp = FP8E5M2_MAX_NORM_RCP + return quantize_colwise_mxfp8( + sX_tile, + sO_col_tile, + mS_col_stage, + max_norm_rcp, + tile_row_start, + tile_col_start, + M, N, + ACTIVATION=None, + DTYPE=cfg.DTYPE, + FP8_DTYPE=cfg.DTYPE_COLUMN, + SWIZZLE=cfg.WITH_GEMM_SWIZZLED_SCALES, + TILE_X=TILE_X, + TILE_Y=TILE_Y, + SCALE_DIM=SCALE_DIM, + ) + +def _cfg_to_fn_name(cfg, M, N) -> str: + """Deterministic registry key from (cfg, shape).""" + key = (cfg.DTYPE.__name__, cfg.DTYPE_ROW, cfg.DTYPE_COLUMN, + int(cfg.ROWWISE), int(cfg.COLUMNWISE), + int(cfg.WITH_GEMM_SWIZZLED_SCALES), int(cfg.WITH_AMAX), + cfg.ACTIVATION or "none", + M, N) + h = hashlib.sha1(repr(key).encode()).hexdigest()[:16] + return f"mxfp8_{h}" + +_compile_cache_tvm_ffi: dict = {} + +def _get_compiled_kernel(cfg, M, N): + """Compile the kernel for THIS (cfg, M, N) with LITERAL shapes — every + dimension is a constexpr int, so the AOT wrapper's per-arg type collapses + to `{ void* data; }` (no shape array, no shape check at call time). + + Tradeoff vs sym_int: one compile per (cfg, M, N) instead of one per cfg. + Memory cost is small; the per-call saving is ~7-8 us. Cache key already + includes (M, N) so we never recompile.""" + cache = _compile_cache_tvm_ffi + fn_name = _cfg_to_fn_name(cfg, M, N) + if fn_name in cache: + return cache[fn_name], fn_name + + kernel_obj = MXFP8QuantizeKernel(cfg) + + # TE allocates scale tensors at this padded shape regardless of swizzle + # (see MXFP8Quantizer::get_scale_shape in transformer_engine/pytorch/csrc): + # rowwise: (roundup(M, 128), roundup(N // 32, 4)) + # columnwise: (roundup(M // 32, 4), roundup(N, 128)) + SCALE_R = (((M + 127) // 128) * 128, ((N + 127) // 128) * 4) + SCALE_C = (((M + 127) // 128) * 4, ((N + 127) // 128) * 128) + WS_M = (M + TILE_Y * NUM_TILES - 1) // (TILE_Y * NUM_TILES) + + # stride_order=(1, 0): row-major, dim 1 stride 1. 
1D: (0,). + kw_rm16_2d = dict(stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, assumed_align=16) + kw_rm4_2d = dict(stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, assumed_align=4) + kw_rm4_1d = dict(stride_order=(0,), + memspace=cute.AddressSpace.gmem, assumed_align=4) + def fake(dtype, shape, kw): + return cute.runtime.make_fake_compact_tensor(dtype, shape, **kw) + + in_fake = fake(cfg.DTYPE, (M, N), kw_rm16_2d) + out_row_fake = fake(cute.Uint8, (M, N), kw_rm16_2d) if cfg.ROWWISE else None + scale_row_fake = fake(cute.Uint8, SCALE_R, kw_rm16_2d) if cfg.ROWWISE else None + out_col_fake = fake(cute.Uint8, (M, N), kw_rm16_2d) if cfg.COLUMNWISE else None + scale_col_fake = fake(cute.Uint8, SCALE_C, kw_rm16_2d) if cfg.COLUMNWISE else None + # No amax / no SR for the MXFP8 path: these slots are None. The kernel's + # __call__ takes them as Optional and dead-strips the amax/SR branches. + amax_row_fake = None + amax_col_fake = None + rng_state_fake = None + # Explicit stream arg (kept in the tvm-ffi signature, not env-stream): C++ + # passes the CUDA stream handle as an int64 scalar, decoded as int-as-ptr. + stream_fake = cute.runtime.make_fake_stream() + + compiled = cute.compile( + kernel_obj, + in_fake, # mX + out_row_fake, scale_row_fake, amax_row_fake, # mO_row, mS_row, mA_row + out_col_fake, scale_col_fake, amax_col_fake, # mO_col, mS_col, mA_col + rng_state_fake, # rng_state + stream_fake, # stream + options="--enable-tvm-ffi", + ) + cache[fn_name] = compiled + return compiled, fn_name + + +def get_mxfp8_quantizer( + x: torch.Tensor, + dtype_row: Literal["e4m3", "e5m2", "none"] = "e4m3", + dtype_col: Literal["e4m3", "e5m2", "none"] = "e4m3", + with_gemm_swizzled_scales: bool = False, +) -> FlexQuantizer: + """Compile + register the MXFP8 kernel and return a + FlexQuantizer wired to it. 
+ + The compiled CuTeDSL function is registered into the tvm-ffi global + registry under its deterministic name; the returned quantizer carries that + name in ``quantize_func`` so the C++ ``FlexQuantizer::quantize`` can + resolve and dispatch to it via ``Function::GetGlobalRequired``. + + Note on ``with_gemm_swizzled_scales``: when True, the scale tensors are + emitted in the cuBLAS GEMM-swizzled layout, which pads them up to whole + 128x4 tiles. Because the kernel only writes the valid blocks, the quantizer + zeros the scale buffers on every quantize (FlexQuantizer::quantize issues a + cudaMemsetAsync before dispatch) so the padded entries cuBLAS reads are 0. + With swizzle=False no swizzling/zeroing is done and the scale padding is + undefined. + """ + M, N = x.shape + cutlass_dtype = torch_to_cutlass_dtype(x.dtype) + cfg = MXFP8QuantizeConfig( + cutlass_dtype, + dtype_row=dtype_row, + dtype_column=dtype_col, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + compiled, fn_name = _get_compiled_kernel(cfg, M, N) + _tvm_ffi.register_global_func(fn_name, compiled, override=True) + + quantizer = FlexQuantizer( + dtype_row=str_to_te_dtype(dtype_row), + dtype_column=str_to_te_dtype(dtype_col), + quantize_func=fn_name, + # Dequant isn't implemented for this storage-only milestone, but + # FlexQuantizer::quantize (C++) rejects an empty dequantize_func, so + # carry a placeholder name. It is never resolved/called by quantize(). + dequantize_func=f"{fn_name}_dequant", # TODO: this is fake, implement dequant and remove the placeholder + stochastic_rounding=False, + ) + # Storage-only milestone: no full PyTorch-tensor compatibility / GEMM yet. 
+ quantizer.internal = True + quantizer.optimize_for_gemm = with_gemm_swizzled_scales + return quantizer \ No newline at end of file diff --git a/transformer_engine/common/cutedsl/cast/quantization_utils.py b/transformer_engine/common/cutedsl/cast/quantization_utils.py new file mode 100644 index 0000000000..785185efaa --- /dev/null +++ b/transformer_engine/common/cutedsl/cast/quantization_utils.py @@ -0,0 +1,1125 @@ +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cutlass._mlir.dialects import arith as mlir_arith +from cutlass._mlir.dialects import llvm +from cutlass.base_dsl.compiler import GPUArch +from cutlass.cute.runtime import make_ptr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass.cute.arch import cvt_f32_bf16 + +from types import SimpleNamespace + +# FP8E4M3 max representable value +FP8E4M3_MAX_NORM = 448.0 +FP8E4M3_MAX_NORM_RCP = 1.0 / FP8E4M3_MAX_NORM +FP8E5M2_MAX_NORM = 57344.0 +FP8E5M2_MAX_NORM_RCP = 1.0 / FP8E5M2_MAX_NORM + +# NVFP4 (fp4e2m1) — 4-bit float, max representable value is 6.0 +FP4_E2M1_MAX = 6.0 +FP4_E2M1_MAX_RCP = 1.0 / FP4_E2M1_MAX +# Largest finite f32 — used to clamp the per-block scale inverse against +# division-by-zero (which produces +inf and then NaN downstream). 
FP32_MAX = 3.4028234663852886e38

# Number of explicit mantissa bits in an IEEE-754 binary32 value; the biased
# exponent field starts at this bit position.
FP32_MANTISSA_BITS = 23


@dsl_user_op
def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32:
    """Reinterpret the bits of an f32 as an i32 (`arith.bitcast`)."""
    return Int32(mlir_arith.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))


@dsl_user_op
def _bitcast_i32_to_f32(val: Int32, *, loc=None, ip=None) -> Float32:
    """Reinterpret the bits of an i32 as an f32 (`arith.bitcast`)."""
    return Float32(mlir_arith.bitcast(T.f32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))


@dsl_user_op
def fabs_f32(val: Float32, *, loc=None, ip=None) -> Float32:
    """Absolute value of an f32, computed by clearing the sign bit."""
    val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip)
    abs_i32 = val_i32 & Int32(0x7FFFFFFF)
    return _bitcast_i32_to_f32(abs_i32, loc=loc, ip=ip)


@dsl_user_op
def float_to_e8m0(val: Float32, *, loc=None, ip=None) -> Int32:
    """Branchless float->E8M0: add mantissa mask to round up, clamp to 254.

    Adding 0x7FFFFF carries into the exponent field whenever any mantissa
    bit is set, so the extracted 8-bit exponent is rounded up. The result
    is a biased exponent in an Int32, capped at 254 with a signed min.
    """
    val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip)
    rounded = val_i32 + Int32(0x7FFFFF)
    exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF)
    return Int32(mlir_arith.minsi(
        exponent.ir_value(loc=loc, ip=ip),
        Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip))


@dsl_user_op
def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32:
    """2^(127 - biased_exp) with special-case handling.

    The general case writes `254 - biased_exp` into the f32 exponent field.
    Three inputs are overridden with fixed bit patterns via `arith.select`:

        biased_exp == 255 -> 0x7FFFFFFF (a NaN bit pattern)
        biased_exp == 254 -> 0x00400000 (subnormal, 2^-127)
        biased_exp == 0   -> 0x7F000000 (2^127)
    """
    new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS)
    result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip)
    for (cmp_val, repl_bits) in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]:
        cond = mlir_arith.cmpi(mlir_arith.CmpIPredicate.eq,
                               biased_exp.ir_value(loc=loc, ip=ip),
                               Int32(cmp_val).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
        alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip)
        result = Float32(mlir_arith.select(
            cond, alt.ir_value(loc=loc, ip=ip),
            result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
    return result


@dsl_user_op
def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32:
    """float32 -> fp8e4m3fn via PTX cvt.rn.satfinite.e4m3x2.f32.

    The packed instruction converts two values; a zero is fed as the first
    operand and only the low byte of the result is returned.
    """
    zero = Float32(0.0)
    result_i16 = Int16(llvm.inline_asm(
        T.i16(),
        [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)],
        "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;",
        "=h,f,f", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))
    result_i32 = Int32(mlir_arith.extui(
        T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
    return result_i32 & Int32(0xFF)


@dsl_user_op
def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32:
    """float32 -> fp8e5m2 via PTX cvt.rn.satfinite.e5m2x2.f32.

    Same structure as `cvt_f32_to_fp8e4m3`: zero in the first operand, low
    byte of the packed result returned.
    """
    zero = Float32(0.0)
    result_i16 = Int16(llvm.inline_asm(
        T.i16(),
        [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)],
        "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;",
        "=h,f,f", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))
    result_i32 = Int32(mlir_arith.extui(
        T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
    return result_i32 & Int32(0xFF)


@dsl_user_op
def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32:
    """`fma.rn.f32 d, a, b, c;` — single-instruction fused multiply-add
    matching nvcc's FFMA. Used for explicit `partial += a * b` patterns
    where we need the same rounding as TE's compiler-fused FFMA."""
    return Float32(llvm.inline_asm(
        T.f32(),
        [a.ir_value(loc=loc, ip=ip),
         b.ir_value(loc=loc, ip=ip),
         c.ir_value(loc=loc, ip=ip)],
        "fma.rn.f32 $0, $1, $2, $3;",
        "=f,f,f,f", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))


@dsl_user_op
def tanh_approx(val: Float32, *, loc=None, ip=None) -> Float32:
    """`tanh.approx.f32` — fast tanh approximation. Matches CUDA `__tanhf`."""
    return Float32(llvm.inline_asm(
        T.f32(),
        [val.ir_value(loc=loc, ip=ip)],
        "tanh.approx.f32 $0, $1;",
        "=f,f", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))


@dsl_user_op
def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64:
    """Pack two f32 scalars into a single 64-bit register (`floatx2` layout).

    Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};`
    which lowers to a single register move — no actual memory traffic.
    """
    return Int64(llvm.inline_asm(
        T.i64(),
        [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)],
        "mov.b64 $0, {$1, $2};",
        "=l,f,f", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))


@dsl_user_op
def pack_i32x2(lo: Int32, hi: Int32, *, loc=None, ip=None) -> Int64:
    """i32 sibling of `pack_f32x2` — concat two i32 into a single b64 register.
    Used by NVFP4 to glue two `(bf16,bf16)`/`(f16,f16)` Int32 packs into the
    `Int64` operand the `mul_cvt.*x4` PTX expects."""
    return Int64(llvm.inline_asm(
        T.i64(),
        [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)],
        "mov.b64 $0, {$1, $2};",
        "=l,r,r", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))


@dsl_user_op
def _trunc_i32_to_i16(val: Int32, *, loc=None, ip=None) -> Int16:
    """Narrow an Int32 to Int16 by keeping the low 16 bits.

    Lives here because the existing arith-dialect narrowing pattern requires
    loc/ip kwargs (see other `mlir_arith.trunci` callers); wrapping it as a
    `@dsl_user_op` lets `@cute.jit` bodies use it without plumbing those in."""
    return Int16(mlir_arith.trunci(
        T.i16(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))


@dsl_user_op
def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32:
    """One fp8e4m3 byte (low 8 bits of `byte_i32`) → f32.

    PTX has no direct `cvt.f32.e4m3` for a scalar; route through the packed
    `cvt.rn.f16x2.e4m3x2` and then `cvt.f32.f16`. The high byte of the .b16
    register is forced to zero so the discarded high f16 lane is well-defined."""
    asm = (
        "{\n"
        ".reg .b32 masked; .reg .b16 b16; .reg .b16 b16_hi;\n\t"
        ".reg .b32 f16pair; .reg .b16 lo_f16; .reg .b16 hi_f16;\n\t"
        "and.b32 masked, $1, 0xFF;\n\t"
        "mov.b32 {b16, b16_hi}, masked;\n\t"
        "cvt.rn.f16x2.e4m3x2 f16pair, b16;\n\t"
        "mov.b32 {lo_f16, hi_f16}, f16pair;\n\t"
        "cvt.f32.f16 $0, lo_f16;\n\t"
        "}"
    )
    return Float32(llvm.inline_asm(
        T.f32(),
        [byte_i32.ir_value(loc=loc, ip=ip)],
        asm,
        "=f,r", has_side_effects=False, is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT))


# ---------------------------------------------------------------------------
# 16-bit packed input PTX kit (bf16 / f16)
#
# bf16 and f16 share the same fast-path shape: packed-x2 amax via
# `max.xorsign.abs.<fmt>x2`, then per-lane widen-to-f32 + `mul.f32x2` +
# `cvt.rn.satfinite.<fp8>x2.f32`. Only the opcodes differ. Build one PTX kit
# per format at module load and let the kernel pick the right kit at JIT
# trace time via `cfg.DTYPE` — equivalent to a C++ template arg specialization
# on `IType`, with no runtime branch.
# ---------------------------------------------------------------------------
def _build_packed16_kit(in_fmt: str):
    """Build a kit of PTX wrappers for a 16-bit input format.

    `in_fmt` is the PTX format string ('bf16' or 'f16'); any value other
    than 'bf16' takes the f16 code path. Returns a namespace with the
    per-format ops the rowwise/colwise inner loops need:

        abs_max_x2(Int32, Int32) -> Int32      # `max.xorsign.abs.<fmt>x2`
        max_x2(Int32, Int32) -> Int32          # `max.<fmt>x2` (no abs)
        abs_max_scalar(Int16, Int16) -> Int16  # `max.xorsign.abs.<fmt>`
        bits_to_f32(Int16) -> Float32          # widen one 16-bit element
        x2_lo_to_f32(Int32) -> Float32         # extract+widen low half
        x2_hi_to_f32(Int32) -> Float32         # extract+widen high half
        truncate_f32(Float32) -> Float32       # round to <fmt> precision
        mul_cvt_to_fp8x2(fp8_dtype, relu=False)
            -> callable(Int32, Int64) -> Int32 # fused <fmt>x2 * f32x2 -> fp8x2
        mul_cvt_to_fp4x4(Int64, Int64) -> Int32
                                               # fused <fmt>x4 * f32x2 -> fp4x4
    """

    @dsl_user_op
    def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32:
        return Int32(llvm.inline_asm(
            T.i32(),
            [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)],
            f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;",
            "=r,r,r", has_side_effects=False, is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT))

    @dsl_user_op
    def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32:
        return Int32(llvm.inline_asm(
            T.i32(),
            [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)],
            f"max.{in_fmt}x2 $0, $1, $2;",
            "=r,r,r", has_side_effects=False, is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT))

    @dsl_user_op
    def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16:
        return Int16(llvm.inline_asm(
            T.i16(),
            [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)],
            f"max.xorsign.abs.{in_fmt} $0, $1, $2;",
            "=h,h,h", has_side_effects=False, is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT))

    if in_fmt == "bf16":
        # bf16 == top 16 bits of f32 — widening is a free bit-shift.
        @dsl_user_op
        def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32:
            i32 = Int32(mlir_arith.extui(
                T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
            return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip)

        @dsl_user_op
        def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32:
            return _bitcast_i32_to_f32(
                (bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip)

        @dsl_user_op
        def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32:
            # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal
            # issues. Sign bits from the arith-right shift get zeroed by the
            # left shift.
            return _bitcast_i32_to_f32(
                (bits >> Int32(16)) << Int32(16), loc=loc, ip=ip)

        @dsl_user_op
        def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32:
            """Round f32 to bf16 precision (round-to-nearest-even), keep f32.
            Matches a C++ round-trip cast through bf16
            (`static_cast<float>(static_cast<bf16>(elt))`)."""
            bf16_bits = Int16(llvm.inline_asm(
                T.i16(), [val.ir_value(loc=loc, ip=ip)],
                "cvt.rn.bf16.f32 $0, $1;",
                "=h,f", has_side_effects=False, is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT))
            i32 = Int32(mlir_arith.extui(
                T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
            return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip)
    else:
        # f16 has its own bit layout; widening requires `cvt.f32.f16`.
        @dsl_user_op
        def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32:
            return Float32(llvm.inline_asm(
                T.f32(), [bits.ir_value(loc=loc, ip=ip)],
                "cvt.f32.f16 $0, $1;",
                "=f,h", has_side_effects=False, is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT))

        @dsl_user_op
        def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32:
            lo_i16 = Int16(mlir_arith.trunci(
                T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
            return bits_to_f32(lo_i16, loc=loc, ip=ip)

        @dsl_user_op
        def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32:
            hi_shifted = bits >> Int32(16)
            hi_i16 = Int16(mlir_arith.trunci(
                T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
            return bits_to_f32(hi_i16, loc=loc, ip=ip)

        @dsl_user_op
        def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32:
            """Round f32 to f16 precision, keep f32."""
            f16_bits = Int16(llvm.inline_asm(
                T.i16(), [val.ir_value(loc=loc, ip=ip)],
                "cvt.rn.f16.f32 $0, $1;",
                "=h,f", has_side_effects=False, is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT))
            return Float32(llvm.inline_asm(
                T.f32(), [f16_bits.ir_value(loc=loc, ip=ip)],
                "cvt.f32.f16 $0, $1;",
                "=f,h", has_side_effects=False, is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT))

    def _build_mul_cvt(out_fmt: str, relu: bool = False):
        """Build a fused `<in_fmt>x2 * f32x2 → fp8x2` PTX wrapper.

        The shape is identical across (in_fmt, out_fmt) combos — only the
        widening opcode (`cvt.f32.<in_fmt>`) and the final saturating cvt
        (`cvt.rn.satfinite[.relu].<out_fmt>x2.f32`) differ. Any `out_fmt`
        other than "e4m3" selects the e5m2 opcode.
        """
        out_op = "e4m3x2" if out_fmt == "e4m3" else "e5m2x2"
        # Computed outside the f-string: reusing the enclosing quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        relu_mod = ".relu" if relu else ""
        asm = (
            "{\n"
            ".reg.b64 vp0; .reg.b64 vp1;\n\t"
            ".reg.b32 v1; .reg.b32 v2;\n\t"
            ".reg.b16 vb1; .reg.b16 vb2;\n\t"
            "mov.b32 {vb1, vb2}, $1;\n\t"
            f"cvt.f32.{in_fmt} v1, vb1;\n\t"
            f"cvt.f32.{in_fmt} v2, vb2;\n\t"
            "mov.b64 vp0, {v1, v2};\n\t"
            "mul.f32x2 vp1, vp0, $2;\n\t"
            "mov.b64 {v2, v1}, vp1;\n\t"
            f"cvt.rn.satfinite{relu_mod}.{out_op}.f32 $0, v1, v2;\n\t"
            "}"
        )

        @dsl_user_op
        def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32:
            result_i16 = Int16(llvm.inline_asm(
                T.i16(),
                [val_2x.ir_value(loc=loc, ip=ip),
                 scale_2x.ir_value(loc=loc, ip=ip)],
                asm,
                "=h,r,l", has_side_effects=False, is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT))
            return Int32(mlir_arith.extui(
                T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip))
        return fn

    def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False):
        """Return the fused mul+cvt op for `fp8_dtype`; anything other than
        "e5m2" falls back to e4m3."""
        if fp8_dtype == "e5m2":
            return _build_mul_cvt("e5m2", relu)
        return _build_mul_cvt("e4m3", relu)

    # NVFP4 fused cast: <in_fmt>x4 × f32x2 → fp4e2m1x4 (4 fp4 packed in 16
    # bits). Same shape as `mul_cvt_to_fp8x2` but produces 4 elements at a
    # time because the `cvt.rn.satfinite.e2m1x2.f32` PTX consumes pairs and
    # writes a single byte (high nibble = first input, low nibble = second).
    # The shuffled `mov.b64 {v1, v0}, v01` lines after the muls undo the
    # PTX's hi/lo packing so the resulting byte is naturally
    # `(fp4(elt1) << 4) | fp4(elt0)` — matches TE's C++ asm.
    @dsl_user_op
    def mul_cvt_to_fp4x4(in_4x: Int64, scale_2x: Int64, *, loc=None, ip=None) -> Int32:
        asm = (
            "{\n"
            ".reg.b64 v01; .reg.b64 v23;\n\t"
            ".reg.b16 i0; .reg.b16 i1; .reg.b16 i2; .reg.b16 i3;\n\t"
            ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t"
            ".reg.b8 f0; .reg.b8 f1;\n\t"
            "mov.b64 {i0, i1, i2, i3}, $1;\n\t"
            f"cvt.f32.{in_fmt} v0, i0;\n\t"
            f"cvt.f32.{in_fmt} v1, i1;\n\t"
            f"cvt.f32.{in_fmt} v2, i2;\n\t"
            f"cvt.f32.{in_fmt} v3, i3;\n\t"
            "mov.b64 v01, {v0, v1};\n\t"
            "mov.b64 v23, {v2, v3};\n\t"
            "mul.f32x2 v01, v01, $2;\n\t"
            "mul.f32x2 v23, v23, $2;\n\t"
            "mov.b64 {v1, v0}, v01;\n\t"
            "mov.b64 {v3, v2}, v23;\n\t"
            "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
            "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
            "mov.b32 $0, {f0, f1, f0, f1};\n\t"
            "}"
        )
        return Int32(llvm.inline_asm(
            T.i32(),
            [in_4x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)],
            asm,
            "=r,l,l", has_side_effects=False, is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT))

    return SimpleNamespace(
        abs_max_x2=abs_max_x2,
        max_x2=max_x2,
        abs_max_scalar=abs_max_scalar,
        bits_to_f32=bits_to_f32,
        x2_lo_to_f32=x2_lo_to_f32,
        x2_hi_to_f32=x2_hi_to_f32,
        truncate_f32=truncate_f32,
        mul_cvt_to_fp8x2=mul_cvt_to_fp8x2,
        mul_cvt_to_fp4x4=mul_cvt_to_fp4x4,
    )


_BF16_KIT = _build_packed16_kit("bf16")
_F16_KIT = _build_packed16_kit("f16")


def _is_packed16(dtype) -> bool:
    """True if `dtype` is one of the 16-bit packed input formats."""
    return dtype is cutlass.BFloat16 or dtype is cutlass.Float16


def _packed16_kit(dtype):
    """Trace-time selector — pick a Packed16Kit for the input dtype.

    Returns the f16 kit for `cutlass.Float16` and the bf16 kit for every
    other dtype, so callers should gate on `_is_packed16` first.
    """
    if dtype is cutlass.Float16:
        return _F16_KIT
    return _BF16_KIT


# ---------------------------------------------------------------------------
# Forward-activation registry
#
# Each entry is a Float32 → Float32 callable applied per element before the
# MXFP8 amax + cast.
Selection is by Python string at JIT trace time, so the +# const-expr machinery treats `cfg.ACTIVATION` like a C++ template argument +# — no runtime branch in the inner loop, separate kernel cached per choice. +# +# Math primitives match CUDA fast-math intrinsics so outputs are bit-exact +# with PyTorch's CUDA implementations of the same activations: +# tanh -> tanh.approx.f32 (== __tanhf) +# exp(x) -> exp2.approx.f32(x · log2(e)) (== __expf) +# --------------------------------------------------------------------------- +def _act_relu(x: Float32) -> Float32: + return cute.arch.fmax(x, Float32(0.0)) + + +def _act_gelu(x: Float32) -> Float32: + """Tanh-approximation GELU. Constants and operator grouping match TE's + `transformer_engine/common/util/math.h::gelu` exactly (factored form + `x · (0.5 + 0.5·tanh(x·(a + b·x²)))`) so quantized output is bit-exact + against the C++ fused IS_ACT path. Uses `cute.math.tanh(fastmath=False)` + rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation + kernels without `--use_fast_math` by default, so its `tanhf` is the + IEEE-precise expansion.""" + A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal + B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation + return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) + + +def _act_silu(x: Float32) -> Float32: + """SiLU/Swish: x · σ(x) = x / (1 + e^-x). + Matches TE's `silu` (`val / (1 + expf(-val))`).""" + return x / (Float32(1.0) + cute.arch.exp(-x)) + + +_ACTIVATIONS = { + "relu": _act_relu, + "gelu": _act_gelu, + "silu": _act_silu, +} + + +@dsl_user_op +def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. + + Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). + This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. 
+ """ + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def mul_cvt_f32x4_to_fp4x4(in01: Int64, in23: Int64, scale_2x: Int64, + *, loc=None, ip=None) -> Int32: + """f32x4 sibling of `kit.mul_cvt_to_fp4x4` — for the NVFP4 colwise path + where elements live on a strided column and we've already widened to f32 + for the amax reduction. 
`in01` = pack(f32_0, f32_1), `in23` similarly.""" + asm = ( + "{\n" + ".reg.b64 v01; .reg.b64 v23;\n\t" + ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" + ".reg.b8 f0; .reg.b8 f1;\n\t" + "mov.b64 {v0, v1}, $1;\n\t" + "mov.b64 {v2, v3}, $2;\n\t" + "mov.b64 v01, {v0, v1};\n\t" + "mov.b64 v23, {v2, v3};\n\t" + "mul.f32x2 v01, v01, $3;\n\t" + "mul.f32x2 v23, v23, $3;\n\t" + "mov.b64 {v1, v0}, v01;\n\t" + "mov.b64 {v3, v2}, v23;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 $0, {f0, f1, f0, f1};\n\t" + "}" + ) + return Int32(llvm.inline_asm( + T.i32(), + [in01.ir_value(loc=loc, ip=ip), + in23.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=r,l,l,l", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +def _cvt_f32_to_fp8(fp8_dtype: str): + """Const-expr dispatch: pick the f32→fp8 scalar PTX op based on output dtype. + + `fp8_dtype` is the Python string from `cfg.FP8_DTYPE`, evaluated at JIT + trace time; the unused branch is never traced. + """ + if fp8_dtype == "e5m2": + return cvt_f32_to_fp8e5m2 + return cvt_f32_to_fp8e4m3 + + +def _cvt_f32x2_to_fp8x2(fp8_dtype: str): + """Const-expr dispatch for the packed f32x2→fp8x2 cvt.""" + if fp8_dtype == "e5m2": + return cvt_f32x2_to_fp8e5m2x2 + return cvt_f32x2_to_fp8e4m3x2 + +@cute.jit +def quantize_rowwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). Same purpose. + M, N, # Int32 — full tensor extents; OOB threads skip their + # scale store. 
+ ACTIVATION, + DTYPE, + FP8_DTYPE, + TILE_Y, + SCALE_DIM, + WAVES, + THREADS_PER_WARP, + THREADS_PER_BANK, + PACK_SIZE, +): + tidx, _, _ = cute.arch.thread_idx() + + # Match the C++ reference's thread layout: pairs of adjacent lanes + # share a row (lanes 2k / 2k+1 both own row k), each pair covering + # the two 32-element scale blocks of that row. Express as a cute + # layout mapping `(tid_Y, tid_X) -> tidx` with stride (2, 1): + # linear(tidx) = tid_Y*2 + tid_X, so `get_flat_coord` inverts to + # `(tidx // 2, tidx % 2)` — semantically clearer than the raw + # divmod, and readily reusable if we later partition via TiledCopy. + # print(f"sX_tile: {sX_tile}") + # print(f"sO_row_tile: {sO_row_tile}") + # print(f"mS_row_stage: {mS_row_stage}") + + tiler, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + val_layout=cute.make_layout((1, SCALE_DIM), stride=(0, 1)) + ) + # print(f"tv_layout: {tv_layout}") + # print(f"tiler: {tiler}") + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_row_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] # shape (32,) bf16 + sO_thread = sO_tv[tidx, None] # shape (32,) uint8 + + # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C1%29-%281%2C+32%29%3A%280%2C1%29 + # print(f"sX_thread: {sX_thread}") + # print(f"sO_thread: {sO_thread}") + + sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) + # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. 
+ sO_thread_u32 = cute.make_tensor( + sO_thread_u32_ptr, + cute.make_layout((SCALE_DIM // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + ) + # print(f"sO_thread_u32: {sO_thread_u32}") + + FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") + # For this fast paht we can read in pack of 2 instead of reading individual f16 / bf16 element + _row_fast = (_is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU)) + + if cutlass.const_expr(_row_fast): + # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack + # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute + kit = _packed16_kit(DTYPE) + sX_thread_rw_i32 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int32), + cute.make_layout((1, SCALE_DIM // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + ) + # print(f"sX_thread_rw_i32: {sX_thread_rw_i32}") + # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) + # In total we have 8 waves where each wave reads + in_r = [[None, None] for _ in range(WAVES)] + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 2 # Each bank group will read 2 i32 from their bank + for w in cutlass.range_constexpr(WAVES): + idx = (w * 2 + offset) % (SCALE_DIM // 2) + in_r[w][0] = sX_thread_rw_i32[0, idx] + in_r[w][1] = sX_thread_rw_i32[0, idx + 1] + + # 1. Packed-x2 amax — 2 PTX per wave, 16 total per thread. + # Accumulates `|elt|` in both lanes (with xorsign-drifted signs); + # final horizontal max reduces the two lanes to a single f32. 
+ amax_2x = Int32(0) + # Each wave will use max.xorsign.abs.f16x2 or max.xorsign.abs.bf16x2 to compare 2 packed elements in parallel + for w in cutlass.range_constexpr(WAVES): + if cutlass.const_expr(FUSE_RELU): + # If we fuse relu then we don't want to do abs since negative value will be set to 0 and they will lose comparison automatically + amax_2x = kit.max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.max_x2(amax_2x, in_r[w][1]) + else: + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][1]) + if cutlass.const_expr(FUSE_RELU): + # Compare the 2 packed max without abs + amax_r = cute.arch.fmax( + kit.x2_lo_to_f32(amax_2x), + kit.x2_hi_to_f32(amax_2x), + ) + # For relu the max is at least 0 + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) + else: + # Compare the 2 packed abs max + amax_r = cute.arch.fmax( + fabs_f32(kit.x2_lo_to_f32(amax_2x)), + fabs_f32(kit.x2_hi_to_f32(amax_2x)), + ) + else: + # Since we need to do computation on individual f16 / bf16 elements, we can't read in pack + sX_thread_rw = cute.make_tensor( + sX_thread.iterator, + cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + ) + in_r = [[None] * PACK_SIZE for _ in range(WAVES)] + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank + + if cutlass.const_expr(ACTIVATION is not None): + op = _ACTIVATIONS[ACTIVATION] + + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = _packed16_kit(DTYPE) + amax_r = Float32(0.0) + for w in cutlass.range_constexpr(WAVES): + idx = (w * PACK_SIZE + offset) % SCALE_DIM + for e in cutlass.range_constexpr(PACK_SIZE): + x = Float32(sX_thread_rw[0, idx + e]) + # If IS_ACT, apply activation function to x in f32 + if cutlass.const_expr(ACTIVATION is not None): + # If it's relu, we can handle it later + if not 
cutlass.const_expr(FUSE_RELU): + x = op(x) + # If 16-bit input with activation, truncate to IType + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + x = kit_act.truncate_f32(x) # TODO: Why not just qunatize from f32? + in_r[w][e] = x + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically + else: + amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 + + # 2. E8M0 scale → gmem. mS_row's layout already encodes the swizzle + # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. + biased_exp_r = float_to_e8m0(amax_r * max_norm_rcp) + # mS_row_stage has logical shape (32, 2) and we have 64 threads where each is mapped to one scale factor + # The TV layout is equivalent to https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C+1%29-%281%29 + # but it's too trival so let's just index it directly without using layout + # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout that mappes the logical index to the physical offset + # Irregular shapes: skip the scale store if this thread's logical row / + # col-block lies past the input's actual extents. TMA already zero-fills + # OOB input reads and drops OOB output writes; only the direct scale-byte + # gmem store needs an explicit guard. + scale_row = tile_row_start + tidx // 2 + scale_col_first_elt = tile_col_start + (tidx % 2) * SCALE_DIM + if scale_row < M and scale_col_first_elt < N: + mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) + + # 3. scale + packed fp8 cast → smem as one u32 per wave. 
@cute.jit
def quantize_colwise_mxfp8(
    sX_tile,            # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA
    sO_col_tile,        # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output)
    mS_col_stage,       # colwise scale tensor (1D swizzled, or 2D linear)
    max_norm_rcp,
    tile_row_start,     # Int32 — global row index of this stage's row 0
                        # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores
                        # for irregular shapes.
    tile_col_start,     # Int32 — global col index of this CTA's col 0
                        # (= bidx * TILE_X).
    M, N,               # Int32 — full tensor extents.
    ACTIVATION,
    DTYPE,
    FP8_DTYPE,
    SWIZZLE,
    TILE_X,
    TILE_Y,
    SCALE_DIM,
):
    """Column-wise MXFP8 quantization of one shared-memory tile.

    Thread ``tidx`` owns one column of the tile (thread layout ``(1, TILE_X)``,
    value layout ``(SCALE_DIM, 1)``). For its ``SCALE_DIM`` elements it:

    1. computes the absolute maximum (after the optional activation),
    2. turns ``amax * max_norm_rcp`` into an E8M0 scale byte with
       ``float_to_e8m0`` and stores it to ``mS_col_stage`` when the scale
       position is inside the ``(M, N)`` extents,
    3. multiplies every element by ``exp2f_rcp(scale_byte)``, converts it to
       FP8 and stores the byte into ``sO_col_tile``.

    Args:
        sX_tile: input tile view, see the signature comment.
        sO_col_tile: output tile view receiving one FP8 byte per element.
        mS_col_stage: destination of the per-column scale byte.
        max_norm_rcp: factor applied to amax before the E8M0 conversion
            (presumably the reciprocal of the FP8 max normal — confirm with
            the caller).
        tile_row_start, tile_col_start: global coordinates of the tile
            origin, used only for the out-of-bounds guard on the scale store.
        M, N: full tensor extents used by that guard.
        ACTIVATION: ``None`` or a key into ``_ACTIVATIONS``.
        DTYPE: input element type, queried through ``_is_packed16`` and
            ``_packed16_kit``.
        FP8_DTYPE: output format, selects the converter via
            ``_cvt_f32_to_fp8``.
        SWIZZLE: compile-time flag selecting the 3-coordinate scale index
            ``(0, tidx % 32, tidx // 32)`` instead of ``(0, tidx)``.
        TILE_X: tile width; also the element stride between consecutive
            column entries in the 16-bit fast path.
        TILE_Y: not referenced in this body (the comments below state
            ``TILE_Y == SCALE_DIM``).
        SCALE_DIM: number of elements sharing one scale.

    Returns:
        The per-thread amax as computed in step 1.
    """
    tidx, _, _ = cute.arch.thread_idx()

    # `tiler` is not used afterwards; only the TV layout is needed.
    tiler, tv_layout = cute.make_layout_tv(
        thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)),
        val_layout=cute.make_layout((SCALE_DIM, 1), stride=(1, 1))
    )

    sX_tv = cute.composition(sX_tile, tv_layout)
    sO_tv = cute.composition(sO_col_tile, tv_layout)

    # I/O elements that belong to this thread.
    # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%281%2C+64%29%3A%2864%2C+1%29-%2832%2C+1%29%3A%281%2C+1%29
    sX_thread = sX_tv[tidx, None]  # e.g. shape (32,) bf16
    sO_thread = sO_tv[tidx, None]  # e.g. shape (32,) uint8

    # Fast path: 16-bit input and no activation, so amax can be computed on
    # the raw 16-bit patterns without widening every element first.
    HALF_PRECISION_PATH = _is_packed16(DTYPE) and ACTIVATION is None

    # 1. amax over this thread's column.
    if cutlass.const_expr(HALF_PRECISION_PATH):
        kit = _packed16_kit(DTYPE)
        # Per-thread Int16 view of the column. Same byte address as
        # `sX_thread` (bf16/fp16 are 16-bit, same width as Int16); the
        # element stride is TILE_X because the column elements are
        # TILE_X apart in the row-major tile.
        # NOTE(review): the stride is hard-coded rather than taken from
        # `sX_thread`'s layout — confirm `sX_tile` is always plain row-major.
        sX_thread_i16 = cute.make_tensor(
            cute.recast_ptr(sX_thread.iterator, dtype=Int16),
            cute.make_layout((SCALE_DIM,), stride=(TILE_X,)),
        )
        # This branch keeps no register copy: the view is read here for the
        # amax and read again in step 3 for the cast.
        amax_bits = Int16(0)
        for i in cutlass.range_constexpr(SCALE_DIM):
            amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i])
        amax_c = fabs_f32(kit.bits_to_f32(amax_bits))
    else:
        # Materialize the column into f32 registers — widen on read so
        # bf16/fp16 inputs become real fp32 values (a pointer recast to
        # Float32 would not widen; it would reinterpret the 16-bit bytes
        # as half of a 32-bit float). Amax and cast both reuse this copy.
        sX_thread_f32 = cute.make_rmem_tensor(
            layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)),
            dtype=Float32,
        )
        for i in cutlass.range_constexpr(SCALE_DIM):
            sX_thread_f32[i] = Float32(sX_thread[i])
        # Apply activation in f32 (no truncation yet)
        if cutlass.const_expr(ACTIVATION is not None):
            op = _ACTIVATIONS[ACTIVATION]
            for i in cutlass.range_constexpr(SCALE_DIM):
                sX_thread_f32[i] = op(sX_thread_f32[i])
        # Numerical truncation through IType so amax/cast match C++.
        # Only needed when 16-bit input + activation; without activation
        # the widening was already exact.
        if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None):
            kit_act = _packed16_kit(DTYPE)
            for i in cutlass.range_constexpr(SCALE_DIM):
                sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i])
        amax_c = Float32(0.0)
        for i in cutlass.range_constexpr(SCALE_DIM):
            amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i]))

    # 2. E8M0 scale → gmem. mS_col's layout already encodes the swizzle
    # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works.
    # Irregular shapes: skip when this stage's row range or this thread's
    # column lies past the input extents. TILE_Y == SCALE_DIM so each stage
    # is exactly one scale-row; valid iff `tile_row_start < M`.
    biased_exp_c = float_to_e8m0(amax_c * max_norm_rcp)
    scale_col = tile_col_start + tidx
    if tile_row_start < M and scale_col < N:
        if cutlass.const_expr(SWIZZLE):
            mS_col_stage[(0, tidx % 32, tidx // 32)] = Uint8(biased_exp_c)
        else:
            mS_col_stage[(0, tidx)] = Uint8(biased_exp_c)

    # 3. scale + FP8 cast → smem (one byte per (row, tidx)). Caller
    # flushes the whole (TILE_Y, TILE_X) tile with a TMA S2G.
    inv_scale_c = exp2f_rcp(biased_exp_c)
    cvt_to_fp8 = _cvt_f32_to_fp8(FP8_DTYPE)
    if cutlass.const_expr(HALF_PRECISION_PATH):
        kit_cast = _packed16_kit(DTYPE)
        for i in cutlass.range_constexpr(SCALE_DIM):
            # Widen the raw 16-bit pattern to f32 right before scaling.
            v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i])
            sO_thread[i] = Uint8(cvt_to_fp8(v_f32 * inv_scale_c))
    else:
        for i in cutlass.range_constexpr(SCALE_DIM):
            sO_thread[i] = Uint8(cvt_to_fp8(sX_thread_f32[i] * inv_scale_c))

    return amax_c
@cute.jit
def quantize_rowwise_nvfp4(
    sX_tile,            # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA
    sO_row_tile,        # (TILE_Y, TILE_X // 2) uint8 smem view (rowwise FP4 output)
    mS_row_stage,       # (TILE_Y, TILE_X // SCALE_DIM) uint8 — one E4M3 byte per (row, scale-block)
    S_enc,              # Float32 — precomputed global encode scale (uniform across threads)
    tile_row_start,     # Int32 — global row index of this stage's row 0
    tile_col_start,     # Int32 — global col index of this CTA's col 0
    M, N,               # Int32 — full tensor extents; OOB threads skip scale store
    DTYPE,
    TILE_Y,
    TILE_X,
    SCALE_DIM,          # = SCALE_DIM_NVFP4 (16); explicit for symmetry with MXFP8 fn
):
    """Rowwise NVFP4 pass — reuses the MXFP8 rowwise 64-thread layout.

    Thread `(row, seg) = (tidx // 2, tidx % 2)` owns one `SEG = TILE_X // 2`
    (=32) element segment of row `row`, columns `[seg*SEG : seg*SEG+SEG]`.
    With MXFP8 `SCALE_DIM=32` that segment is exactly one block; with NVFP4
    `SCALE_DIM=16` it is `BLOCKS_PER_SEG = SEG // SCALE_DIM` (=2) blocks, so
    we loop over them — symmetric with `quantize_colwise_nvfp4`'s per-column
    `BLOCKS_PER_COL` loop. This keeps rowwise NVFP4 on the same 64-thread
    CTA as every other pass instead of needing 128 threads.

    Per scale block the thread:

    1. reads `SCALE_DIM // 2` packed Int32 words (two 16-bit inputs each),
    2. reduces them to a block amax with `kit.abs_max_x2`,
    3. converts `amax * (S_enc * FP4_E2M1_MAX_RCP)` to an E4M3 byte and
       stores it to `mS_row_stage[(row, sb)]` when the block starts inside
       the `(M, N)` extents,
    4. derives `block_scale_inverse` from the *rounded* scale byte, clamped
       to `FP32_MAX`, and writes `SCALE_DIM // 4` Int16 words of packed fp4.

    Args:
        sX_tile, sO_row_tile, mS_row_stage: tile views, see the signature
            comments.
        S_enc: global encode scale; its reciprocal `S_dec` is computed here.
        tile_row_start, tile_col_start, M, N: used only by the
            out-of-bounds guard on the scale store.
        DTYPE: input element type. `_packed16_kit(DTYPE)` is called
            unconditionally, so this pass presumably supports only 16-bit
            inputs — confirm against the caller.
        TILE_Y, TILE_X: tile extents used to build the thread/value layouts.
        SCALE_DIM: elements per scale block.

    Returns:
        None. Results are written to `sO_row_tile` and `mS_row_stage`.
    """
    tidx, _, _ = cute.arch.thread_idx()

    SEG = TILE_X // 2                    # 32 — elements per thread (half a row)
    BLOCKS_PER_SEG = SEG // SCALE_DIM    # 2 — NVFP4 scale-blocks in that segment

    # Same TV layout as MXFP8 rowwise (thr (TILE_Y, 2), val (1, SEG)); the
    # output view packs 2 fp4 per byte so its val extent is SEG // 2.
    _, tv_layout_in = cute.make_layout_tv(
        thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)),
        val_layout=cute.make_layout((1, SEG), stride=(0, 1)),
    )
    _, tv_layout_out = cute.make_layout_tv(
        thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)),
        val_layout=cute.make_layout((1, SEG // 2), stride=(0, 1)),
    )
    sX_tv = cute.composition(sX_tile, tv_layout_in)
    sO_tv = cute.composition(sO_row_tile, tv_layout_out)

    sX_thread = sX_tv[tidx, None]    # (SEG,) bf16/fp16
    sO_thread = sO_tv[tidx, None]    # (SEG // 2,) uint8

    row = tidx // 2
    seg = tidx % 2

    # Packed views: read 2 elts/Int32, write 2 fp4-bytes/Int16.
    # Both views assume unit stride along the row for this thread's segment.
    kit = _packed16_kit(DTYPE)
    sX_thread_i32 = cute.make_tensor(
        cute.recast_ptr(sX_thread.iterator, dtype=Int32),
        cute.make_layout((SEG // 2,), stride=(1,)),
    )
    sO_thread_i16 = cute.make_tensor(
        cute.recast_ptr(sO_thread.iterator, dtype=Int16),
        cute.make_layout((SEG // 4,), stride=(1,)),
    )

    S_dec = Float32(1.0) / S_enc
    fp4_max_inv = Float32(FP4_E2M1_MAX_RCP)

    # Loop over the BLOCKS_PER_SEG scale-blocks within this thread's segment.
    # Each block: amax over SCALE_DIM elements → E4M3 scale byte →
    # block_scale_inverse → SCALE_DIM/4 fp4x4 stores. Identical math to the
    # colwise path, just along X with the packed-x2 amax fast read.
    for blk in cutlass.range_constexpr(BLOCKS_PER_SEG):
        i32_base = blk * (SCALE_DIM // 2)    # SCALE_DIM // 2 Int32 words per block (8 for NVFP4)
        in_r = [None] * (SCALE_DIM // 2)     # trace-time Python list holding the block's words
        for w in cutlass.range_constexpr(SCALE_DIM // 2):
            in_r[w] = sX_thread_i32[i32_base + w]

        # 1. amax via packed-x2 max.xorsign.abs; the two lanes are reduced
        #    to one f32 afterwards, with fabs applied to each lane.
        amax_2x = Int32(0)
        for w in cutlass.range_constexpr(SCALE_DIM // 2):
            amax_2x = kit.abs_max_x2(amax_2x, in_r[w])
        amax_r = cute.arch.fmax(
            fabs_f32(kit.x2_lo_to_f32(amax_2x)),
            fabs_f32(kit.x2_hi_to_f32(amax_2x)),
        )

        # 2. Per-block E4M3 scale (saturating cvt → block_amax=0 ⇒ fp8 zero).
        S_dec_b_f32_pre = amax_r * (S_enc * fp4_max_inv)
        S_dec_b_byte = cvt_f32_to_fp8e4m3(S_dec_b_f32_pre)

        # Scale-block index within the row: seg covers blocks
        # [seg*BLOCKS_PER_SEG : +BLOCKS_PER_SEG].
        sb = seg * BLOCKS_PER_SEG + blk
        scale_row = tile_row_start + row
        scale_col_first_elt = tile_col_start + sb * SCALE_DIM
        # Only the scale store is guarded; the fp4 data below is always
        # written to the smem output tile.
        if scale_row < M and scale_col_first_elt < N:
            mS_row_stage[(row, sb)] = Uint8(S_dec_b_byte)

        # 3. block_scale_inverse = min(1 / (f32(S_dec_b_fp8) * S_dec), FLT_MAX).
        #    Uses the rounded scale byte (not the pre-rounding f32 value).
        S_dec_b_f32 = cvt_fp8e4m3_to_f32(S_dec_b_byte)
        block_scale_inverse = cute.arch.fmin(
            Float32(1.0) / (S_dec_b_f32 * S_dec), Float32(FP32_MAX))
        scale_2x = pack_f32x2(block_scale_inverse, block_scale_inverse)

        # 4. Cast SCALE_DIM elements → SCALE_DIM/4 × fp4x4 (2 bytes each).
        #    Block blk occupies output bytes [blk*(SCALE_DIM/2) : +SCALE_DIM/2],
        #    i.e. Int16 slots [blk*(SCALE_DIM/4) : +SCALE_DIM/4].
        i16_base = blk * (SCALE_DIM // 4)
        for w in cutlass.range_constexpr(SCALE_DIM // 4):
            # Two Int32 words = four 16-bit inputs → one fp4x4 result.
            packed_4 = pack_i32x2(in_r[2 * w], in_r[2 * w + 1])
            quad = kit.mul_cvt_to_fp4x4(packed_4, scale_2x)
            sO_thread_i16[i16_base + w] = _trunc_i32_to_i16(quad)

    return
+ S_enc, # Float32 — precomputed global encode scale + tile_row_start, # Int32 — global row index of this stage's row 0 + tile_col_start, # Int32 — global col index of this CTA's col 0 + M, N, + DTYPE, + TILE_X, + TILE_Y, + SCALE_DIM, # = SCALE_DIM_NVFP4 (16) +): + tidx, _, _ = cute.arch.thread_idx() + BLOCKS_PER_COL = TILE_Y // SCALE_DIM # e.g. 2 for MXFP8's TILE_Y=32 reused + + # Input thread layout — same `(1, TILE_X)` as MXFP8 colwise so the + # launch reuses TILE_X threads. Each thread owns one full input column + # (TILE_Y elements). + _, tv_layout_in = cute.make_layout_tv( + thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), + val_layout=cute.make_layout((TILE_Y, 1), stride=(1, 1)), + ) + sX_tv = cute.composition(sX_tile, tv_layout_in) + sX_thread = sX_tv[tidx, None] # (TILE_Y,) bf16/fp16, column tidx + + # The per-stage output tile arrives as a rank-1 nested mode (like the MXFP8 + # smem tiles); rebuild a flat rank-2 (TILE_X, TILE_Y // 2) view over the same + # bytes so the (col, byte) stores below index cleanly. + sO_col_2d = cute.make_tensor( + sO_col_tile.iterator, + cute.make_layout((TILE_X, TILE_Y // 2), stride=(TILE_Y // 2, 1)), + ) + + # Column elements are TILE_X apart in memory → widen to f32 once and + # do everything (amax, mul, fp4 cast) from f32 registers. + sX_thread_f32 = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((TILE_Y,), stride=(1,)), + dtype=Float32, + ) + if cutlass.const_expr(_is_packed16(DTYPE)): + # Widen via bits → f32 (avoids `Float32(bf16)` going through a + # generic conversion path; matches the MXFP8 colwise fast read). 
+ kit = _packed16_kit(DTYPE) + sX_thread_i16 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int16), + cute.make_layout((TILE_Y,), stride=(TILE_X,)), + ) + for i in cutlass.range_constexpr(TILE_Y): + sX_thread_f32[i] = kit.bits_to_f32(sX_thread_i16[i]) + else: + for i in cutlass.range_constexpr(TILE_Y): + sX_thread_f32[i] = Float32(sX_thread[i]) + + S_dec = Float32(1.0) / S_enc + fp4_max_inv = Float32(FP4_E2M1_MAX_RCP) + + # Loop over the BLOCKS_PER_COL scale-blocks down the column. Each + # iteration is a self-contained NVFP4 quantization of SCALE_DIM + # elements: amax → scale byte → block_scale_inv → SCALE_DIM/4 fp4x4 + # stores. The byte CONTENTS pack vertical pairs of input elements + # (2i, 2i+1) of column tidx; storage shape is transposed so writes are + # contiguous along the row axis of `sO_col_tile`. + for blk in cutlass.range_constexpr(BLOCKS_PER_COL): + base = blk * SCALE_DIM + # 1. amax over this scale block + amax_c = Float32(0.0) + for i in cutlass.range_constexpr(SCALE_DIM): + amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[base + i])) + + # 2. Per-block E4M3 scale + S_dec_b_f32_pre = amax_c * (S_enc * fp4_max_inv) + S_dec_b_byte = cvt_f32_to_fp8e4m3(S_dec_b_f32_pre) + + # OOB guard: this thread's column extent vs N, this block's row + # extent vs M. Transposed scale tensor — col coord first. + scale_row_first = tile_row_start + base + scale_col = tile_col_start + tidx + if scale_row_first < M and scale_col < N: + mS_col_stage[(tidx, blk)] = Uint8(S_dec_b_byte) + + # 3. block_scale_inverse (same min/FLT_MAX clamp as rowwise) + S_dec_b_f32 = cvt_fp8e4m3_to_f32(S_dec_b_byte) + block_scale_inverse = cute.arch.fmin( + Float32(1.0) / (S_dec_b_f32 * S_dec), Float32(FP32_MAX)) + scale_2x = pack_f32x2(block_scale_inverse, block_scale_inverse) + + # 4. Cast SCALE_DIM elements → SCALE_DIM/4 × fp4x4. 
Output is + # transposed `(TILE_X, TILE_Y // 2)` uint8, so the two bytes of + # each fp4x4 land at consecutive byte offsets along row `tidx` — + # we can write Int16 directly instead of two scattered byte stores. + # Recast row `tidx` to an Int16 view of length TILE_Y // 4 per + # full thread (or SCALE_DIM // 4 per scale-block iter). + for w in cutlass.range_constexpr(SCALE_DIM // 4): + in01 = pack_f32x2(sX_thread_f32[base + 4 * w], + sX_thread_f32[base + 4 * w + 1]) + in23 = pack_f32x2(sX_thread_f32[base + 4 * w + 2], + sX_thread_f32[base + 4 * w + 3]) + quad = mul_cvt_f32x4_to_fp4x4(in01, in23, scale_2x) + # quad low 16 bits: byte0 = (fp4(elt 4w+1) << 4) | fp4(elt 4w+0), + # byte1 = (fp4(elt 4w+3) << 4) | fp4(elt 4w+2). Pair (4w, 4w+1) → + # transposed-output col `blk*(SCALE_DIM/2) + 2w`; pair (4w+2, 4w+3) → + # col `blk*(SCALE_DIM/2) + 2w + 1`. + byte_lo = quad & Int32(0xFF) + byte_hi = (quad >> Int32(8)) & Int32(0xFF) + out_col = blk * (SCALE_DIM // 2) + 2 * w + sO_col_2d[(tidx, out_col)] = Uint8(byte_lo) + sO_col_2d[(tidx, out_col + 1)] = Uint8(byte_hi) + + # No single-scalar amax return for colwise — multiple scale blocks + # per thread per call leaves no canonical value (and the PoC caller + # doesn't need it). Matches the C++ NVFP4 colwise path which does + # the cross-block aggregation outside. 
def torch_to_cutlass_dtype(torch_dtype):
    """Return the cutlass dtype registered for ``torch_dtype``.

    Raises:
        ValueError: if ``torch_dtype`` has no entry in
            ``_torch_to_cutlass_dtype``.
    """
    if torch_dtype in _torch_to_cutlass_dtype:
        return _torch_to_cutlass_dtype[torch_dtype]
    raise ValueError(f"Unsupported torch dtype: {torch_dtype}")


def str_to_te_dtype(str_dtype):
    """Return the Transformer Engine dtype registered for ``str_dtype``.

    The table maps ``"none"`` to ``None``, so a ``None`` result is a valid
    lookup, not a miss.

    Raises:
        ValueError: if ``str_dtype`` has no entry in ``_str_to_te_dtype``.
    """
    if str_dtype in _str_to_te_dtype:
        return _str_to_te_dtype[str_dtype]
    raise ValueError(f"Unsupported string dtype: {str_dtype}")
+ */ + NVTE_FLEX_1D_SCALING = 5, NVTE_INVALID_SCALING = 100 }; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 779b145dd9..e9db6dfe4b 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -407,6 +407,44 @@ class NVFP4Quantizer : public Quantizer { cudaStream_t stream); }; +class FlexQuantizer : public Quantizer { + public: + explicit FlexQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_FLEX_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; + + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + + std::pair convert_and_update_tensor(py::object tensor) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Reconstruct a high-precision tensor by dispatching this + * quantizer's registered tvm-ffi dequantize_func. 
*/ + void dequantize(const TensorWrapper& input, TensorWrapper& out); + + std::vector get_scale_shape(const std::vector& shape, DType dtype, bool columnwise) const; + std::vector get_scale_shape(size_t flat_first_dim, size_t flat_last_dim, DType dtype, bool columnwise) const; + + private: + // If nullopt, then skip quantizing that direction + std::optional dtype_row; + std::optional dtype_column; + std::string quantize_func; + std::string dequantize_func; + bool stochastic_rounding = false; +}; + std::unique_ptr convert_quantizer(py::handle quantizer); std::vector getTensorShape(const at::Tensor& t); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..ae86c71cdc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -341,6 +341,8 @@ py::object nvfp4_quantize_with_amax(const at::Tensor &tensor, py::handle quantiz py::object dequantize(const py::handle &input, DType otype); +py::object dequantize_with_quantizer(const py::handle &input, DType otype, py::handle quantizer); + py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims, std::optional tensor_offsets); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aab5a87b9a..b6065be5eb 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -20,6 +20,7 @@ #include "common/util/system.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" +#include "tvm_ffi_bridge.h" namespace transformer_engine { namespace pytorch { @@ -442,6 +443,30 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype) return out; } +py::object dequantize_with_quantizer(const py::handle &input, transformer_engine::DType otype, + py::handle quantizer) { + init_extension(); 
+ + // If other Quantizer types also support customized dequantization, they should be allowed here as well + NVTE_CHECK(!quantizer.is_none() && detail::IsFlexQuantizers(quantizer.ptr()), + "dequantize_with_quantizer expects a FlexQuantizer"); + + auto quantizer_cpp = convert_quantizer(quantizer); + auto *flex = static_cast(quantizer_cpp.get()); + + // Interpret the (multi-format) quantized input via the quantizer. + const auto &input_tensor = makeTransformerEngineTensor(input, quantizer); + + // Output is always a plain high-precision tensor; allocate with NoneQuantizer. + NoneQuantizer out_alloc{py::none()}; + const auto &shape = convertShape(input_tensor.shape()); + auto [out_tensor, out] = out_alloc.create_tensor(shape, otype); + + NVTE_SCOPED_GIL_RELEASE({ flex->dequantize(input_tensor, out_tensor); }); + + return out; +} + py::object group_dequantize(const py::handle &input, transformer_engine::DType otype) { using namespace pybind11::literals; init_extension(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..8902d0055b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -18,7 +18,6 @@ #include "../common.h" #include "../extensions.h" -#include "common.h" namespace transformer_engine::pytorch { @@ -35,6 +34,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +PyTypeObject *FlexTensorPythonClass = nullptr; +PyTypeObject *FlexTensorStoragePythonClass = nullptr; +PyTypeObject *FlexQuantizerClass = nullptr; PyTypeObject *GroupedTensorPythonClass = nullptr; PyTypeObject *GroupedTensorStoragePythonClass = nullptr; std::once_flag extension_init_flag; @@ -103,6 +105,21 @@ void init_nvfp4_extensions() { "Internal error: could not initialize 
pyTorch NVFP4 extension."); } +void init_flex_extensions() { + auto flex_module = py::module_::import("transformer_engine.pytorch.tensor.flex_tensor"); + FlexQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(flex_module.ptr(), "FlexQuantizer")); + FlexTensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(flex_module.ptr(), "FlexTensor")); + auto flex_base_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.flex_tensor_storage"); + FlexTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(flex_base_module.ptr(), "FlexTensorStorage")); + NVTE_CHECK(FlexQuantizerClass != nullptr && FlexTensorPythonClass != nullptr && + FlexTensorStoragePythonClass != nullptr, + "Internal error: could not initialize pyTorch Flex extension."); +} + void init_grouped_tensor_extension() { if (GroupedTensorPythonClass && GroupedTensorStoragePythonClass) return; auto grouped_tensor_module = @@ -125,6 +142,7 @@ void init_extension() { init_mxfp8_extension(); init_float8blockwise_extension(); init_nvfp4_extensions(); + init_flex_extensions(); init_grouped_tensor_extension(); }); } @@ -197,6 +215,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("dequantize_with_quantizer", &transformer_engine::pytorch::dequantize_with_quantizer, + "Dequantize through the quantizer's registered dequantize_func", py::arg("input"), + py::arg("otype"), py::arg("quantizer")); m.def("create_empty_quantized_tensor", &transformer_engine::pytorch::create_empty_quantized_tensor, "Create an empty quantized tensor", py::arg("quantizer"), py::arg("shape"), diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9e640537f9..c4fb2c266e 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -43,6 
+43,9 @@ extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; +extern PyTypeObject *FlexTensorPythonClass; +extern PyTypeObject *FlexTensorStoragePythonClass; +extern PyTypeObject *FlexQuantizerClass; extern PyTypeObject *GroupedTensorPythonClass; extern PyTypeObject *GroupedTensorStoragePythonClass; @@ -81,6 +84,14 @@ inline bool IsNVFP4Tensor(PyObject *obj) { return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass; } +inline bool IsFlexQuantizers(PyObject *obj) { return Py_TYPE(obj) == FlexQuantizerClass; } + +inline bool IsFlexTensor(PyObject *obj) { + return Py_TYPE(obj) == FlexTensorPythonClass || Py_TYPE(obj) == FlexTensorStoragePythonClass; +} + +TensorWrapper NVTETensorFromFlexTensor(py::handle tensor, Quantizer *quantization_params); + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -113,7 +124,9 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, - CreateQuantizer)}; + CreateQuantizer), + std::make_tuple(IsFlexTensor, IsFlexQuantizers, NVTETensorFromFlexTensor, + CreateQuantizer)}; } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 5fc50953a1..24ecb3504f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -11,7 +11,7 @@ #include "common/util/system.h" #include "pybind.h" #include "torch/torch.h" - +#include "tvm_ffi_bridge.h" // convert_to_dltype, DLTensorWrapper, tvm::ffi::Function namespace transformer_engine::pytorch { namespace { @@ -2623,4 
+2623,588 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +FlexQuantizer::FlexQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { + // The quantization format for row-wise direction + if (quantizer.attr("dtype_row").is_none()) { + this->dtype_row = std::nullopt; + this->rowwise_usage = false; + } else { + DType dtype_row = quantizer.attr("dtype_row").cast(); + if (is_fp8_dtype(dtype_row) || is_fp4_dtype(dtype_row)) { + this->dtype_row = dtype_row; + this->rowwise_usage = true; + } else { + NVTE_ERROR("Row-wise quantization for FlexQuantizer currently does not support this dtype"); + } + } + + // The quantization format for column-wise direction + if (quantizer.attr("dtype_column").is_none()) { + this->dtype_column = std::nullopt; + this->columnwise_usage = false; + } else { + DType dtype_col = quantizer.attr("dtype_column").cast(); + if (is_fp8_dtype(dtype_col) || is_fp4_dtype(dtype_col)) { + this->dtype_column = dtype_col; + this->columnwise_usage = true; + } else { + NVTE_ERROR("Column-wise quantization for FlexQuantizer currently does not support this dtype"); + } + } + + NVTE_CHECK(rowwise_usage || columnwise_usage, "FlexQuantizer should have at least one direction quantized."); + + // Name of the tvm-ffi global function (a CuTeDSL kernel the Python side + // already compiled + registered via tvm_ffi.register_global_func). quantize() + // resolves it from the registry and calls it directly — no Python on the hot + // path. (Resolution is deferred to quantize() so common.h stays free of the + // tvm-ffi headers.) 
+ this->quantize_func = quantizer.attr("quantize_func").cast(); + this->dequantize_func = quantizer.attr("dequantize_func").cast(); + this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); +} + +void FlexQuantizer::set_quantization_params(TensorWrapper* tensor) const {} + +std::pair FlexQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); + using namespace pybind11::literals; + + // Scaling factor format + const bool with_gemm_swizzled_scales = this->optimize_for_gemm; + + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + const auto rowwise_scale_inv_shape = this->dtype_row + ? std::optional(get_scale_shape(flat_first_dim, flat_last_dim, *this->dtype_row, false)) + : std::nullopt; + const auto columnwise_scale_inv_shape = this->dtype_column + ? std::optional(get_scale_shape(flat_first_dim, flat_last_dim, *this->dtype_column, true)) + : std::nullopt; + + // Allocate tensors for quantized data and scaling factors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise_tensor; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise_tensor; + + const auto bit8_tensor_opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); + const auto bit32_tensor_opts = + at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory); + + if (this->dtype_row) { + if (is_fp8_dtype(*this->dtype_row)) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape->begin(), + rowwise_scale_inv_shape->end()); + rowwise_data_tensor = at::empty(shape_int64, bit8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + } else if (is_fp4_dtype(*this->dtype_row)) { + const std::vector 
scale_inv_shape_int64(rowwise_scale_inv_shape->begin(), + rowwise_scale_inv_shape->end()); + rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_rowwise_tensor = at::empty({1}, bit32_tensor_opts); + } + } + + if (this->dtype_column) { + if (is_fp8_dtype(*this->dtype_column)) { + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape->begin(), + columnwise_scale_inv_shape->end()); + columnwise_data_tensor = at::empty(shape_int64, bit8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + } else if (is_fp4_dtype(*this->dtype_column)) { + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape->begin(), + columnwise_scale_inv_shape->end()); + columnwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_columnwise_tensor = at::empty({1}, bit32_tensor_opts); + } + } + + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? 
py::cast(tensor) : py::none(); + }; + + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + auto amax_rowwise_py = py_cast(amax_rowwise_tensor, rowwise_usage); + auto amax_columnwise_py = py_cast(amax_columnwise_tensor, columnwise_usage); + + py::object out_py; + + if (internal) { + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["amax_rowwise"] = amax_rowwise_py; + kwargs["amax_columnwise"] = amax_columnwise_py; + kwargs["dtype_row"] = py::cast(this->dtype_row); + kwargs["dtype_column"] = py::cast(this->dtype_column); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["fake_dtype"] = GetATenDType(dtype); + + PyObject* result = PyObject_Call(reinterpret_cast(FlexTensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create FlexTensorStorage instance"); + out_py = py::reinterpret_steal(result); + } else { + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(shape_int64); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + 
kwargs["amax_rowwise"] = amax_rowwise_py; + kwargs["amax_columnwise"] = amax_columnwise_py; + kwargs["dtype_row"] = py::cast(this->dtype_row); + kwargs["dtype_column"] = py::cast(this->dtype_column); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["device"] = py::cast(device); + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(FlexTensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create FlexTensor instance"); + out_py = py::reinterpret_steal(result); + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_FLEX_1D_SCALING); + if (rowwise_usage && this->dtype_row) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), *this->dtype_row, shape); + if (is_fp8_dtype(*this->dtype_row)) { + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, *rowwise_scale_inv_shape); + } else if (is_fp4_dtype(*this->dtype_row)) { + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, *rowwise_scale_inv_shape); + out_cpp.set_amax(amax_rowwise_tensor.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise_tensor)); + } + } + if (columnwise_usage && this->dtype_column) { + if (is_fp8_dtype(*this->dtype_column)) { + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), *this->dtype_column, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, *columnwise_scale_inv_shape); + } else if (is_fp4_dtype(*this->dtype_column)) { + // Follow the pattern of NVFP4's columnwise data layout + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + 
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
+                                       *columnwise_scale_inv_shape);
+      out_cpp.set_columnwise_amax(amax_columnwise_tensor.data_ptr(), DType::kFloat32,
+                                  std::vector{1});
+    }
+  }
+  out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
+  this->set_quantization_params(&out_cpp);
+
+  return {std::move(out_cpp), std::move(out_py)};
+}
+
+std::pair FlexQuantizer::create_grouped_tensor(
+    size_t num_tensors, const std::vector& logical_shape, DType dtype,
+    py::object quantizer, const std::optional& first_dims, size_t logical_first_dim,
+    size_t logical_last_dim) const {
+  // TODO: fix this
+  NVTE_ERROR("Not implemented yet");
+}
+
+// Adapt an existing Python FlexTensor to this quantizer's configuration.
+//
+// - Checks that `tensor` is a FlexTensor and holds data in at least one direction.
+// - Recovers the logical (unquantized) shape from whichever data buffer is present.
+// - For every enabled direction (dtype_row / dtype_column set) allocates any missing
+//   data, scale_inv and (FP4 only) amax buffer and stores it on the Python object.
+// - For every disabled direction resets the corresponding Python attributes to None.
+// - Returns a TensorWrapper over the resulting buffers together with `tensor`.
+std::pair FlexQuantizer::convert_and_update_tensor(
+    py::object tensor) const {
+  NVTE_CHECK(detail::IsFlexTensor(tensor.ptr()), "FlexQuantizer must output to FlexTensor.");
+
+  // Scaling factor format
+  const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
+
+  // Extract buffers from Python tensor
+  auto get_tensor = [&tensor](const char* name) -> std::optional {
+    auto attr_py = tensor.attr(name);
+    if (attr_py.is_none()) {
+      return std::nullopt;
+    }
+    return attr_py.cast();
+  };
+  auto rowwise_data = get_tensor("_rowwise_data");
+  auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
+  auto columnwise_data = get_tensor("_columnwise_data");
+  auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
+  auto amax_rowwise = get_tensor("_amax_rowwise");
+  auto amax_columnwise = get_tensor("_amax_columnwise");
+  NVTE_CHECK(rowwise_data || columnwise_data, "FlexTensor has no data.");
+
+  // Tensor dimensions, shape means original shape
+  std::vector shape;
+  if (rowwise_data && this->dtype_row) {
+    if (is_fp4_dtype(*this->dtype_row)) {
+      shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+      // Cross-check against the column-wise buffer only when one is actually present:
+      // the tensor may carry row-wise data alone (the column-wise buffer is allocated
+      // further below), and dereferencing an empty optional is undefined behaviour.
+      if (columnwise_data && this->dtype_column && is_fp4_dtype(*this->dtype_column)) {
+        // If both rowwise and columnwise directions are NVFP4 quantized, check if they match
+        auto col_shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
+        NVTE_CHECK(get_2d_dims(shape) == get_2d_dims(col_shape), "NVFP4 row-wise data (shape=", shape,
+                   ") and column-wise data (shape=", col_shape, ") do not match");
+      }
+    } else if (is_fp8_dtype(*this->dtype_row)) {
+      shape = getTensorShape(*rowwise_data);
+    }
+  } else if (columnwise_data && this->dtype_column) {
+    if (is_fp4_dtype(*this->dtype_column)) {
+      shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
+    } else if (is_fp8_dtype(*this->dtype_column)) {
+      shape = getTensorShape(*columnwise_data);
+    }
+  } else {
+    NVTE_ERROR("FlexTensor has neither of rowwise and columnwise data");
+  }
+
+  const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape);
+  const std::vector shape_int64(shape.begin(), shape.end());
+  const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+
+  // Coerce row-wise data
+  if (this->dtype_row) {
+    if (!rowwise_data) {
+      if (is_fp8_dtype(*this->dtype_row)) {
+        rowwise_data = at::empty(shape_int64, uint8_opts);
+      } else if (is_fp4_dtype(*this->dtype_row)) {
+        rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), uint8_opts);
+      } else {
+        NVTE_ERROR("Unsupported dtype for row-wise quantization in FlexQuantizer: ",
+                   static_cast(*this->dtype_row));
+      }
+      tensor.attr("_rowwise_data") = *rowwise_data;
+    }
+    if (!rowwise_scale_inv) {
+      if (is_fp8_dtype(*this->dtype_row)) {
+        const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_row, false);
+        const std::vector scale_inv_shape_int64(scale_inv_shape.begin(),
+                                                scale_inv_shape.end());
+        rowwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts);
+      } else if (is_fp4_dtype(*this->dtype_row)) {
+        const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_row, false);
+        const std::vector scale_inv_shape_int64(scale_inv_shape.begin(),
+                                                scale_inv_shape.end());
+        rowwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts);
+      } else {
+        NVTE_ERROR("Unsupported dtype for row-wise quantization in FlexQuantizer: ",
+                   static_cast(*this->dtype_row));
+      }
+      tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
+    }
+    if (is_fp4_dtype(*this->dtype_row) && (!amax_rowwise || amax_rowwise->numel() != 1)) {
+      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+      amax_rowwise = at::empty({1}, opts);
+      tensor.attr("_amax_rowwise") = *amax_rowwise;
+    }
+  } else {  // rowwise_usage == false
+    if (rowwise_data) {
+      rowwise_data.reset();
+      tensor.attr("_rowwise_data") = py::none();
+    }
+    if (rowwise_scale_inv) {
+      rowwise_scale_inv.reset();
+      tensor.attr("_rowwise_scale_inv") = py::none();
+    }
+    if (amax_rowwise) {
+      amax_rowwise.reset();
+      tensor.attr("_amax_rowwise") = py::none();
+    }
+  }
+
+  // Coerce column-wise data
+  if (this->dtype_column) {
+    if (!columnwise_data) {
+      if (is_fp8_dtype(*this->dtype_column)) {
+        columnwise_data = at::empty(shape_int64, uint8_opts);
+      } else if (is_fp4_dtype(*this->dtype_column)) {
+        // Enforce a 2D shape: for an [S, B, H] input B can be 1, in which case the
+        // transposed shape [H, S, B] has a last dim of 1 and dividing it by 2 gives zero.
+        std::vector shape_int64_2d = {static_cast(flat_first_dim),
+                                      static_cast(flat_last_dim)};
+        const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d);
+        columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), uint8_opts);
+      } else {
+        NVTE_ERROR("Unsupported dtype for column-wise quantization in FlexQuantizer: ",
+                   static_cast(*this->dtype_column));
+      }
+      tensor.attr("_columnwise_data") = *columnwise_data;
+    }
+    if (!columnwise_scale_inv) {
+      if (is_fp8_dtype(*this->dtype_column)) {
+        const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_column, true);
+        const std::vector scale_inv_shape_int64(scale_inv_shape.begin(),
+                                                scale_inv_shape.end());
+        columnwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts);
+      } else if (is_fp4_dtype(*this->dtype_column)) {
+        const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_column, true);
+        const std::vector scale_inv_shape_int64(scale_inv_shape.begin(),
+                                                scale_inv_shape.end());
+        columnwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts);
+      } else {
+        NVTE_ERROR("Unsupported dtype for column-wise quantization in FlexQuantizer: ",
+                   static_cast(*this->dtype_column));
+      }
+      tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
+    }
+    if (is_fp4_dtype(*this->dtype_column) && (!amax_columnwise || amax_columnwise->numel() != 1)) {
+      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+      amax_columnwise = at::empty({1}, opts);
+      tensor.attr("_amax_columnwise") = *amax_columnwise;
+    }
+  } else {  // columnwise_usage == false
+    if (columnwise_data) {
+      columnwise_data.reset();
+      tensor.attr("_columnwise_data") = py::none();
+    }
+    if (columnwise_scale_inv) {
+      columnwise_scale_inv.reset();
+      tensor.attr("_columnwise_scale_inv") = py::none();
+    }
+    if (amax_columnwise) {
+      amax_columnwise.reset();
+      tensor.attr("_amax_columnwise") = py::none();
+    }
+  }
+
+  // Coerce other attrs
+  tensor.attr("_dtype_row") = py::cast(this->dtype_row);
+  tensor.attr("_dtype_column") = py::cast(this->dtype_column);
+  tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales;
+
+  // Construct C++ flex tensor
+  TensorWrapper out_cpp(NVTE_FLEX_1D_SCALING);
+  if (this->dtype_row) {
+    out_cpp.set_rowwise_data(rowwise_data->data_ptr(), *this->dtype_row, shape);
+    if (is_fp8_dtype(*this->dtype_row)) {
+      out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
+                                    getTensorShape(*rowwise_scale_inv));
+    } else if (is_fp4_dtype(*this->dtype_row)) {
+      out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
+                                    getTensorShape(*rowwise_scale_inv));
+      out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, getTensorShape(*amax_rowwise));
+    }
+  }
+  if (this->dtype_column) {
+    if
(is_fp8_dtype(*this->dtype_column)) {
+      out_cpp.set_columnwise_data(columnwise_data->data_ptr(), *this->dtype_column, shape);
+      out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
+                                       getTensorShape(*columnwise_scale_inv));
+    } else if (is_fp4_dtype(*this->dtype_column)) {
+      std::vector shape_2d = {flat_first_dim, flat_last_dim};
+      auto col_data_shape_fp4 = make_transpose_shape(shape_2d);
+      out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1,
+                                  col_data_shape_fp4);
+      out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
+                                       getTensorShape(*columnwise_scale_inv));
+      out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
+                                  getTensorShape(*amax_columnwise));
+    }
+  }
+  out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
+  this->set_quantization_params(&out_cpp);
+
+  return {std::move(out_cpp), std::move(tensor)};
+}
+
+// Quantize `input` into the buffers of `out` by calling the tvm-ffi global function
+// registered under `quantize_func`. Returns immediately for an empty input.
+// NOTE: `noop_flag` is accepted for interface compatibility but is not forwarded to
+// the kernel.
+void FlexQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
+                             const std::optional& noop_flag) {
+  if (input.numel() == 0) {
+    return;
+  }
+  NVTE_CHECK(!this->quantize_func.empty(),
+             "FlexQuantizer requires a registered quantize_func name");
+  // The quantizer is expected to be fully configured (both kernel names) before use;
+  // report the name that is actually missing.
+  NVTE_CHECK(!this->dequantize_func.empty(),
+             "FlexQuantizer requires a registered dequantize_func name");
+
+  // TODO: stochastic rounding needs a per-call RNG state (seed/offset)
+  // minted from the torch CUDA generator (cf. NVFP4Quantizer::quantize) and
+  // passed as a trailing kernel arg. Not wired yet — only the non-SR path works.
+  NVTE_CHECK(!this->stochastic_rounding,
+             "FlexQuantizer: stochastic rounding is not implemented yet");
+
+  // The CuTeDSL kernel runs on the current CUDA device; NVTEBasicTensor carries
+  // no device, so stamp the current one onto the DLTensors we hand to tvm-ffi.
+  const int32_t dev_index = static_cast(at::cuda::current_device());
+
+  // Kernel arg order (a generic MX-style block-scaling signature):
+  //   (mX, mO_row, mS_row, mA_row, mO_col, mS_col, mA_col)
+  // mX is the high-precision input; the 6 outputs are `out`'s freshly-allocated
+  // buffers, grouped per direction as {data, scale_inv, amax}. Any buffer with a
+  // null data_ptr (a disabled direction, or amax for a format without one, e.g.
+  // MXFP8) is passed to the kernel as None, so a kernel only consumes the slots
+  // it actually needs.
+  constexpr int kNumArgs = 7;
+  const std::array nvte_args = {
+      input.get_rowwise_data(),        // mX
+      out.get_rowwise_data(),          // mO_row
+      out.get_rowwise_scale_inv(),     // mS_row
+      out.get_amax(),                  // mA_row (primary / row-wise amax)
+      out.get_columnwise_data(),       // mO_col
+      out.get_columnwise_scale_inv(),  // mS_col
+      out.get_columnwise_amax(),       // mA_col
+  };
+
+  // Named locals: each DLTensorWrapper owns the synthesized DLTensor shape/
+  // stride storage (NVTEBasicTensor -> DLTensor) and must outlive the call. A
+  // null-data_ptr buffer stays nullopt -> to_ffi_arg yields None.
+  std::array, kNumArgs> dltensors;
+  for (int i = 0; i < kNumArgs; ++i) {
+    if (nvte_args[i].data_ptr != nullptr) {
+      dltensors[i].emplace(nvte_args[i], dev_index);
+    }
+  }
+
+  // Trailing rng_state slot (seed/offset for stochastic rounding). SR isn't
+  // implemented yet (guarded above), so this stays None for now -- but keeping
+  // the slot fixes the kernel signature (7 tensor slots + rng_state, followed by
+  // the stream handle): when SR lands, only this wrapper gets populated (minted
+  // from the torch generator), not the arg list.
+  std::optional rng_w;  // None
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // When the quantizer produces GEMM-swizzled scales (optimize_for_gemm /
+  // with_gemm_swizzled_scales=True), the scale tensor is padded up to whole
+  // 128x4 tiles and the kernel only writes the valid blocks. The padded (OOB)
+  // entries must be zeroed because cuBLAS reads the full swizzled tiles during
+  // the GEMM. So passing swizzle=True means the scale buffers are zeroed here,
+  // each quantize call, before the kernel overwrites the valid region (matches
+  // quantize_mxfp8.cuh's per-call cudaMemsetAsync). Scales are 1-byte
+  // (E8M0/E4M3), so the element count is the byte count.
+  if (this->optimize_for_gemm) {
+    const auto &scale_row = dltensors[2];  // mS_row
+    if (scale_row) {
+      NVTE_CHECK_CUDA(cudaMemsetAsync(scale_row->data, 0, scale_row->numel(), stream));
+    }
+    const auto &scale_col = dltensors[5];  // mS_col
+    if (scale_col) {
+      NVTE_CHECK_CUDA(cudaMemsetAsync(scale_col->data, 0, scale_col->numel(), stream));
+    }
+  }
+
+  // CUDA stream handle, passed as an int64 scalar. The CuTeDSL kernel takes an
+  // explicit stream arg (make_fake_stream, not env-stream); tvm-ffi decodes the
+  // int as the stream pointer (decode_param_stream -> allow_int_as_ptr).
+  const int64_t stream_handle = reinterpret_cast(stream);
+
+  // Variadic dispatch: each to_ffi_arg(...) is an Optional (None
+  // when absent), packed by tvm-ffi's operator() and kept alive through the
+  // whole call expression. The wrappers back the views and outlive the call.
+  //   (mX, mO_row, mS_row, mA_row, mO_col, mS_col, mA_col, rng_state, stream)
+  call_tvm_ffi(this->quantize_func,
+               to_ffi_arg(dltensors[0]),  // input tensor
+               to_ffi_arg(dltensors[1]),  // rowwise quantized data
+               to_ffi_arg(dltensors[2]),  // rowwise scale_inv
+               to_ffi_arg(dltensors[3]),  // rowwise amax
+               to_ffi_arg(dltensors[4]),  // columnwise quantized data
+               to_ffi_arg(dltensors[5]),  // columnwise scale_inv
+               to_ffi_arg(dltensors[6]),  // columnwise amax
+               to_ffi_arg(rng_w),         // rng_state (used for stochastic rounding if you need it)
+               stream_handle              // CUDA stream handle as int64
+  );
+}
+
+void FlexQuantizer::dequantize(const TensorWrapper& input, TensorWrapper& out) {
+  if (input.numel() == 0) {
+    return;
+  }
+  NVTE_CHECK(!this->dequantize_func.empty(),
+             "FlexQuantizer requires a registered dequantize_func name");
+
+  // The CuTeDSL kernel runs on the current CUDA device; NVTEBasicTensor carries
+  // no device, so stamp the current one onto the DLTensors we hand to tvm-ffi.
+  const int32_t dev_index = static_cast(at::cuda::current_device());
+
+  // Kernel arg order (inverse of quantize: the high-precision destination
+  // first, then the quantized inputs grouped per direction as {data, scale,
+  // amax}):
+  //   (mO, mX_row, mS_row, mA_row, mX_col, mS_col, mA_col)
+  // mO is the freshly-allocated high-precision output; the 6 inputs are the
+  // quantized buffers. Any buffer with a null data_ptr (a disabled direction,
+  // or amax for a format without one, e.g. MXFP8) is passed as None, so the
+  // kernel only consumes the slots it needs (typically just one direction is
+  // required to reconstruct).
+ constexpr int kNumArgs = 7; + const std::array nvte_args = { + out.get_rowwise_data(), // mO (high-precision destination) + input.get_rowwise_data(), // mX_row + input.get_rowwise_scale_inv(), // mS_row + input.get_amax(), // mA_row + input.get_columnwise_data(), // mX_col + input.get_columnwise_scale_inv(), // mS_col + input.get_columnwise_amax(), // mA_col + }; + + // Named locals: each DLTensorWrapper owns the synthesized DLTensor shape/ + // stride storage and must outlive the call. A null-data_ptr buffer stays + // nullopt -> to_ffi_arg yields None. + std::array, kNumArgs> w; + for (int i = 0; i < kNumArgs; ++i) { + if (nvte_args[i].data_ptr != nullptr) { + w[i].emplace(nvte_args[i], dev_index); + } + } + + // CUDA stream handle, passed as an int64 scalar (see quantize() above). + const int64_t stream_handle = + reinterpret_cast(static_cast(at::cuda::getCurrentCUDAStream())); + + // (mO, mX_row, mS_row, mA_row, mX_col, mS_col, mA_col, stream) + call_tvm_ffi(this->dequantize_func, to_ffi_arg(w[0]), to_ffi_arg(w[1]), to_ffi_arg(w[2]), + to_ffi_arg(w[3]), to_ffi_arg(w[4]), to_ffi_arg(w[5]), to_ffi_arg(w[6]), + stream_handle); +} + +std::vector FlexQuantizer::get_scale_shape(const std::vector& shape, DType dtype, + bool columnwise) const { + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + return get_scale_shape(flat_first_dim, flat_last_dim, dtype, columnwise); +} + +std::vector FlexQuantizer::get_scale_shape(size_t flat_first_dim, size_t flat_last_dim, + DType dtype, bool columnwise) const { + // Each direction uses its own format's scale-factor shape, mirroring the + // per-format MXFP8Quantizer / NVFP4Quantizer get_scale_shape implementations. 
+ if (is_fp8_dtype(dtype)) { + NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, + " (got dims=", flat_first_dim, "x", flat_last_dim, ")"); + if (columnwise) { + return {roundup(flat_first_dim / MXFP8_BLOCK_SIZE, 4), roundup(flat_last_dim, 128)}; + } + return {roundup(flat_first_dim, 128), roundup(flat_last_dim / MXFP8_BLOCK_SIZE, 4)}; + } + if (is_fp4_dtype(dtype)) { + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0 && flat_last_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got dims=", flat_first_dim, "x", flat_last_dim, ")"); + if (columnwise) { + return {roundup(flat_last_dim, 128), roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4)}; + } + return {roundup(flat_first_dim, 128), roundup(flat_last_dim / NVFP4_BLOCK_SIZE, 4)}; + } + NVTE_ERROR("Unsupported dtype for FlexQuantizer::get_scale_shape"); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h b/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h new file mode 100644 index 0000000000..9da1246210 --- /dev/null +++ b/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h @@ -0,0 +1,171 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_TVM_FFI_BRIDGE_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_TVM_FFI_BRIDGE_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transformer_engine/transformer_engine.h" +#include "util/logging.h" + +namespace py = pybind11; + +// --------------------------------------------------------------------------- +// dtype conversion helpers — overload resolution picks by argument type. +// --------------------------------------------------------------------------- + +// NOTE: at::Tensor -> DLTensor goes through at::toDLPackNonOwning, which fills +// the DLDataType from the torch dtype automatically, so no c10::ScalarType +// overload is needed. Only TE's own NVTEBasicTensor (NVTEDType, no DLPack +// support) requires this manual mapping. +inline DLDataType convert_to_dltype(NVTEDType type) { + switch (type) { + case kNVTEFloat32: return DLDataType{kDLFloat, 32, 1}; + case kNVTEFloat16: return DLDataType{kDLFloat, 16, 1}; + case kNVTEBFloat16: return DLDataType{kDLBfloat, 16, 1}; + case kNVTEByte: return DLDataType{kDLUInt, 8, 1}; + case kNVTEInt32: return DLDataType{kDLInt, 32, 1}; + case kNVTEInt64: return DLDataType{kDLInt, 64, 1}; + // FP8 / E8M0 → raw 1-byte uint; the kernel interprets the bits. + case kNVTEFloat8E4M3: return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E5M2: return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E8M0: return DLDataType{kDLUInt, 8, 1}; + default: NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); + } +} + +// --------------------------------------------------------------------------- +// DLTensorWrapper — DLTensor with managed shape/strides storage. +// +// Subclassing DLTensor (a POD C struct) lets the wrapper IS-A DLTensor: you +// can take its address and pass it directly to `tvm::ffi::TensorView`. 
The +// shape/strides arrays the base struct points at are either borrowed from a +// PyTorch tensor (zero copy) or owned by the wrapper itself (when built +// from an NVTE tensor that doesn't store them in int64_t form). +// --------------------------------------------------------------------------- +class DLTensorWrapper : public DLTensor { + public: + // Zero-copy borrow via torch's own non-owning DLPack export: fills our + // base DLTensor in place (data/shape/strides/dtype/device/byte_offset) + // using torch's canonical field extraction — no heap alloc, no deleter, + // no refcount. shape/strides point into the at::Tensor's internal arrays, + // so the caller must keep `tensor` alive through any use of this wrapper. + explicit DLTensorWrapper(const at::Tensor &tensor) { + NVTE_CHECK(tensor.defined(), "DLTensorWrapper: undefined at::Tensor"); + at::toDLPackNonOwning(tensor, static_cast(this)); + this->numel_ = static_cast(tensor.numel()); + } + + // NVTEBasicTensor stores shape as size_t and has no strides. We allocate + // owned int64 buffers for both: copy the shape, synthesize row-major + // contiguous strides (TE tensors are always contiguous). 
+ DLTensorWrapper(const NVTEBasicTensor &tensor, int32_t device_index) { + const int n = static_cast(tensor.shape.ndim); + shape_buf_ = std::make_unique(n); + strides_buf_ = std::make_unique(n); + int64_t stride = 1; + for (int i = n - 1; i >= 0; --i) { + shape_buf_[i] = static_cast(tensor.shape.data[i]); + strides_buf_[i] = stride; + stride *= shape_buf_[i]; + } + this->numel_ = stride; // product of all dims + this->data = tensor.data_ptr; + this->device = DLDevice{kDLCUDA, device_index}; + this->ndim = n; + this->dtype = convert_to_dltype(tensor.dtype); + this->shape = shape_buf_.get(); + this->strides = strides_buf_.get(); + this->byte_offset = 0; + } + + ~DLTensorWrapper() = default; + DLTensorWrapper(const DLTensorWrapper &) = delete; + DLTensorWrapper &operator=(const DLTensorWrapper &) = delete; + DLTensorWrapper(DLTensorWrapper &&) = default; + DLTensorWrapper &operator=(DLTensorWrapper &&) = default; + + // Number of elements (product of shape), cached at construction. For 1-byte + // dtypes (FP8 / E8M0 / E4M3) this equals the byte count. + int64_t numel() const { return this->numel_; } + + private: + int64_t numel_ = 0; + std::unique_ptr shape_buf_; + std::unique_ptr strides_buf_; +}; + +// --------------------------------------------------------------------------- +// Turn an optionally-present DLTensorWrapper into a tvm-ffi call argument: +// present -> TensorView over its (borrowed) DLTensor +// absent -> None +// tvm-ffi's TypeTraits> packs an empty Optional as None and a +// present one as the inner T (here, the TensorView). The returned Optional +// holds the TensorView by value, so when passed as a call-site argument it +// stays alive through the whole `fn(...)` expression (including CallPacked). +// The DLTensorWrapper it views must itself outlive the call (keep it as a +// named local — its synthesized shape/stride buffers back the DLTensor). 
+// --------------------------------------------------------------------------- +inline tvm::ffi::Optional to_ffi_arg( + const std::optional &wrapper) { + if (wrapper.has_value()) { + return tvm::ffi::TensorView(static_cast(&wrapper.value())); + } + return std::nullopt; +} + +// --------------------------------------------------------------------------- +// call_tvm_ffi — resolve a global tvm-ffi function by name and call it. +// +// Each argument is auto-packed into an AnyView by tvm-ffi's variadic +// `Function::operator()`. That operator keeps the call-site argument +// temporaries alive through its internal CallPacked (they are bound to +// forwarding references whose lifetime spans the full `fn(...)` expression), +// so passing TensorView / Optional temporaries here is safe — no +// need to park them in named arrays the way a manual AnyView[] + CallPacked +// loop would require. +// +// We use GetGlobal (returns std::nullopt on miss) rather than +// GetGlobalRequired so we can raise a domain-specific error. `fn_name` is the +// quantizer's cache key: it encodes every compile-time (constexpr) property the +// registered CuTeDSL kernel was specialized for (dtypes, per-direction formats, +// swizzle, baked shapes, ...). A registered kernel therefore *guarantees* that +// signature. A lookup miss means no kernel was registered for this exact +// signature — i.e. the caller is asking for a constexpr configuration the +// kernel author never compiled/registered, so the constexpr guarantee is +// broken. That is a setup bug (key mismatch), not a runtime input error, so we +// fail loudly with the offending name. +// --------------------------------------------------------------------------- +template +inline tvm::ffi::Any call_tvm_ffi(const std::string &fn_name, Args &&...args) { + std::optional fn = tvm::ffi::Function::GetGlobal(fn_name); + NVTE_CHECK(fn.has_value(), + "No tvm-ffi kernel registered under '", fn_name, + "'. 
This name is the quantizer's cache key, which encodes the "
+             "kernel's compile-time (constexpr) signature; a miss means the "
+             "registered kernel's constexpr guarantee does not match what is "
+             "being requested (no kernel was compiled/registered for this "
+             "exact configuration).");
+  return (*fn)(std::forward(args)...);
+}
+
+#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_TVM_FFI_BRIDGE_H_
diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp
index ddb85808a5..3758443d53 100644
--- a/transformer_engine/pytorch/csrc/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/type_converters.cpp
@@ -174,6 +174,71 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   return ret;
 }
 
+// Wrap the buffers held by a Python FlexTensor / FlexTensorStorage in a TensorWrapper.
+//
+// A direction is considered present when its `_rowwise_data` / `_columnwise_data`
+// attribute is not None. Per-direction metadata mirrors what FlexQuantizer sets when
+// it builds the same wrapper: FP8 data carries E8M0 scale factors; FP4 data carries
+// E4M3 scale factors plus an FP32 amax. `quantizer` must be non-null; it is given a
+// chance to attach its own state via set_quantization_params().
+TensorWrapper NVTETensorFromFlexTensor(py::handle tensor, Quantizer *quantizer) {
+  auto ret = TensorWrapper(NVTE_FLEX_1D_SCALING);
+
+  const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
+  const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
+  const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast();
+
+  NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for Flex Tensor.");
+
+  // Row-scaled data
+  if (rowwise_usage) {
+    const DType dtype_row = tensor.attr("_dtype_row").cast();
+    const auto &rowwise_data = tensor.attr("_rowwise_data").cast();
+    const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast();
+    if (is_fp4_dtype(dtype_row)) {
+      // The uint8 buffer holds packed FP4 values (allocated via convert_shape_for_fp4),
+      // so report the unpacked logical shape rather than the storage shape.
+      const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast();
+      ret.set_rowwise_data(rowwise_data.data_ptr(), dtype_row,
+                           convert_shape_back_from_fp4(getTensorShape(rowwise_data), false));
+      ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
+                                getTensorShape(scale_inv));
+      ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
+    } else {
+      // FlexQuantizer only admits FP8 or FP4 dtypes, so anything else here is FP8.
+      ret.set_rowwise_data(rowwise_data.data_ptr(), dtype_row, getTensorShape(rowwise_data));
+      ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
+                                getTensorShape(scale_inv));
+    }
+  }
+
+  // Column-scaled data
+  if (columnwise_usage) {
+    const DType dtype_column = tensor.attr("_dtype_column").cast();
+    const auto &columnwise_data = tensor.attr("_columnwise_data").cast();
+    const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast();
+    if (is_fp4_dtype(dtype_column)) {
+      // Unpack only (no transpose): this yields the transposed logical shape when the
+      // buffer was allocated transposed, as FlexQuantizer::convert_and_update_tensor does.
+      // NOTE(review): FlexQuantizer::create_tensor allocates this buffer untransposed --
+      // confirm which layout is intended.
+      const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast();
+      ret.set_columnwise_data(columnwise_data.data_ptr(), dtype_column,
+                              convert_shape_back_from_fp4(getTensorShape(columnwise_data), false));
+      ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
+                                   getTensorShape(scale_inv));
+      ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
+                              getTensorShape(amax_columnwise));
+    } else {
+      ret.set_columnwise_data(columnwise_data.data_ptr(), dtype_column,
+                              getTensorShape(columnwise_data));
+      ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
+                                   getTensorShape(scale_inv));
+    }
+  }
+
+  // Scale layout
+  ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
+
+  // Quantizer state
+  quantizer->set_quantization_params(&ret);
+
+  return ret;
+}
+
 NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) {
   auto *quantizer_ptr = quantizer.ptr();
   if (IsMXFP8Quantizers(quantizer_ptr)) {
diff --git a/transformer_engine/pytorch/tensor/flex_tensor.py b/transformer_engine/pytorch/tensor/flex_tensor.py
new file mode 100644
index 0000000000..0640eaf0fe
--- /dev/null
+++ b/transformer_engine/pytorch/tensor/flex_tensor.py
@@ -0,0 +1,864 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+ +"""Tensor class with hybrid (mixed-format) quantized data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple, Union, Any +import warnings + +import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState +from transformer_engine.pytorch.ops.basic.quantize import Quantize +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe +from ..constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE +from ..utils import canonicalize_shape, devices_match, round_up_to_nearest_multiple +from .storage.flex_tensor_storage import FlexTensorStorage, _FromFlexFunc +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc + +aten = torch.ops.aten + +class FlexQuantizer(Quantizer): + """Builder class for Flex tensors that are quantized in potentially both directions + + High-precision tensors (e.g. in FP32 or BF16) are quantized in row-wise and column-wise + directions with potentially different quantization formats. For example, it can be + quantized in MXFP8 in one direction and NVFP4 in another, or simply only quantized in only one. + + The quantization & dequantization logic is implemented in + the new experimental CuTeDSL kernel instead of the CUDA C++ kernel. + """ + + dtype_row: TE_DType + dtype_column: TE_DType + + def __init__( + self, + *, + dtype_row: Optional[TE_DType], + dtype_column: Optional[TE_DType], + quantize_func: str, + dequantize_func: str, + stochastic_rounding: bool = False, + ): + """ + Parameters + ---------- + dtype_row, dtype_column : Optional[TE_DType] + Per-direction quantized formats. ``None`` means that direction is + not produced (the corresponding data/scale/amax slots are passed as + None to the kernel). 
+ quantize_func, dequantize_func : str + Names of the tvm-ffi global functions (registered CuTeDSL kernels) + that C++ ``FlexQuantizer::quantize`` / ``::dequantize`` resolve by + name (``Function::GetGlobal``) and call. + + These names double as the kernel **cache key**, and that key MUST + encode *exactly* everything the kernel was specialized for at + compile time — every constexpr the ``@cute.jit`` baked in: the + high-precision input dtype, the per-direction quantized formats, + the swizzle flag, and any baked tensor extents (e.g. a literal N). + Two requirements follow: + + * If a constexpr changes, the name MUST change too (recompile + + register under the new name). The name is the contract: a given + name *guarantees* a kernel specialized for that exact signature. + * Conversely, anything NOT baked (e.g. a sym-int dim) must NOT + appear in the name, or you fragment the cache and recompile + needlessly. + + Because the name is the only thing C++ carries to the call site, + there is no deeper signature check: if the constexpr signature and + the name ever disagree, the lookup simply misses. C++ then raises + (see ``call_tvm_ffi``): a miss means no kernel was registered for + this exact signature — the constexpr guarantee is broken — and that + is surfaced as an error rather than silently mis-dispatching. + stochastic_rounding : bool + If True, quantize() mints a fresh RNG state (seed/offset) from the + torch CUDA generator each call and passes it in the ``rng_state`` + slot below. The kernel must be a variant compiled with SR enabled. + (Not implemented yet — C++ throws when this is set.) 
+ + tvm-ffi calling protocol (POSITIONAL — there is no runtime signature + check beyond what the compiled kernel itself enforces; both functions + MUST accept arguments in exactly this order): + + idx quantize_func dequantize_func dtype + --- ---------------------- ---------------------- ----------------- + 0 mX (high-prec INPUT) mO (high-prec OUTPUT) fp32/bf16/fp16 + 1 mO_row data OUTPUT mX_row data INPUT uint8 (fp8/fp4) + 2 mS_row scale OUTPUT mS_row scale INPUT uint8 (e8m0/e4m3) + 3 mA_row amax OUTPUT mA_row amax INPUT fp32 (None: MXFP8) + 4 mO_col data OUTPUT mX_col data INPUT uint8 (fp8/fp4) + 5 mS_col scale OUTPUT mS_col scale INPUT uint8 (e8m0/e4m3) + 6 mA_col amax OUTPUT mA_col amax INPUT fp32 (None: MXFP8) + 7 rng_state (SR seed) stream int64 / handle + 8 stream -- int64 handle + + Both functions share the SAME order for slots 0..6 — the only + difference is direction: for quantize, slot 0 is the high-precision + INPUT and slots 1..6 are the quantized OUTPUTS; for dequantize it is + reversed (slot 0 is the high-precision OUTPUT, slots 1..6 the quantized + INPUTS). Per-direction slots are grouped {data, scale_inv, amax}; a + disabled direction (dtype is None) or a format without amax (e.g. MXFP8) + passes None for that slot. quantize_func additionally takes ``rng_state`` + (slot 7, None unless stochastic_rounding) before the trailing CUDA + ``stream`` (passed as an int64 handle); dequantize_func omits rng_state, + so its ``stream`` is slot 7. + + What each slot is for (so a kernel can take only what it needs and the + rest are passed as None): + + - mX / mO (slot 0): the high-precision tensor. ALWAYS present — the + value being quantized (quantize: read) or reconstructed (dequantize: + write). + + - Row-wise group (slots 1..3) and column-wise group (slots 4..6): the + tensor quantized along two orientations. Block scaling is applied + along the last (contiguous) axis for the row-wise output and along + the first axis for the column-wise output. 
Training needs BOTH + because the forward and backward GEMMs consume the operand in + opposite layouts (e.g. ``x`` row-wise for fprop, column-wise for the + wgrad/dgrad GEMM). Inference or a one-sided use needs only ONE — set + the unused direction's dtype to None and its three slots arrive as + None. + + * data (slots 1 / 4): the packed FP8/FP4 bytes for that direction. + The actual quantized payload; produce it for any direction you + want to keep. + * scale_inv (slots 2 / 5): the per-block scale factors for that + direction (E8M0 for MXFP8, E4M3 for NVFP4). Always paired with + ``data`` — a block-scaled format is meaningless without its + scales, so if you produce ``data`` you must produce ``scale_inv``. + * amax (slots 3 / 6): the per-direction amax (max |x|). Only formats + whose scale derives from a tensor/global amax need this (e.g. + NVFP4's FP32 global scale, current-scaling recipes). MXFP8 picks a + power-of-two scale per 32-block directly from the block, so it + needs NO amax → pass None. + + - rng_state (quantize slot 7): seed/offset for the stochastic-rounding + RNG. Only needed if the kernel does stochastic rounding; otherwise + None (and leave ``stochastic_rounding=False``). Absent from + dequantize_func entirely. + + - stream (trailing): the CUDA stream to launch on, passed by C++ as an + int64 handle (decoded as an opaque pointer). Always present. + """ + if dtype_row is None and dtype_column is None: + raise ValueError( + "FlexQuantizer requires at least one direction to be quantized, " + "but both dtype_row and dtype_column are None." + ) + super().__init__(rowwise=dtype_row is not None, columnwise=dtype_column is not None) + self.dtype_row = dtype_row + self.dtype_column = dtype_column + self.quantize_func = quantize_func + self.dequantize_func = dequantize_func + # FIXME(flex): kernels are bound to one (cfg, shape). 
Take an is_valid(tensor) + # and a recompile(tensor) -> new func name per direction; in quantize_impl/ + # dequantize, recompile when invalid instead of erroring on shape change. + # When True, quantize() mints a fresh RNG state (seed/offset) from the + # torch CUDA generator each call and passes it to the kernel as the + # trailing arg, for stochastic rounding. The kernel must be a variant + # compiled with stochastic rounding enabled. + self.stochastic_rounding = stochastic_rounding + + def copy(self) -> FlexQuantizer: + """Create shallow copy""" + + quantizer = FlexQuantizer( + dtype_row=self.dtype_row, + dtype_column=self.dtype_column, + quantize_func=self.quantize_func, + dequantize_func=self.dequantize_func, + stochastic_rounding=self.stochastic_rounding, + ) + quantizer.internal = self.internal + quantizer.optimize_for_gemm = self.optimize_for_gemm + return quantizer + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag = None + ) -> QuantizedTensor: + assert isinstance(dst, FlexTensor), f"Cannot store quantized MXFP8 in {type(dst)} type." 
+ + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update quantized tensor metadata + dst._dtype_row = self.dtype_row + dst._dtype_column = self.dtype_column + + return dst + + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: + return False + return True + + def calibrate(self, tensor: torch.Tensor) -> None: + pass # Calibration is no-op since this supports only blockwise quantization, which doesn't require calibration. + + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + dtype: TE_DType + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for Flex quantization. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + For NXFP4 1D blockwise quantization, blocksize is 16 + - If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4)) + - If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4)) + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. 
+ CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if FlexTensor.is_nvfp4_dtype(dtype): + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + if columnwise: + outer = round_up_to_nearest_multiple(K, 128) + inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + outer = round_up_to_nearest_multiple(M, 128) + inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + + elif FlexTensor.is_mxfp8_dtype(dtype): + if columnwise: + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + else: + raise ValueError(f"Unsupported quantization dtype {dtype} for Flex quantizer!") + + def get_data_shape( + self, + shape: Iterable[int], + columnwise: bool, + dtype: TE_DType, + ) -> Tuple[int, ...]: + """Physical (byte) shape of the quantized DATA buffer for one direction. + + Mirrors ``FlexQuantizer::create_tensor`` (C++): FP8 stores the logical + shape in both directions; NVFP4 packs two values per byte (and the + column-wise buffer is transposed), so the packed dim is halved. + + Parameters + ---------- + shape : Iterable[int] + Logical shape of the tensor being quantized. + columnwise : bool + Column-wise data buffer (True) or row-wise (False). + dtype : TE_DType + Per-direction quantized format. + + Returns + ------- + Tuple[int, ...] + Shape of the uint8 data buffer for that direction. 
    def _get_compatible_recipe(self) -> Union[Recipe, None]:
        """Get a compatible recipe for this quantizer, if any.

        Returns
        -------
        Union[Recipe, None]
            Always ``None``: the Flex quantizer is treated as orthogonal to
            the choice of recipe, so it does not advertise a specific one.
        """
        # TODO: revisit whether a recipe should be reported here.
        return None
    def __new__(
        cls,
        *args,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: Optional[torch.Tensor],
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: Optional[torch.Tensor],
        amax_rowwise: Optional[torch.Tensor],
        amax_columnwise: Optional[torch.Tensor],
        dtype_row: Optional[TE_DType],
        dtype_column: Optional[TE_DType],
        quantizer: Optional[Quantizer],
        with_gemm_swizzled_scales: bool,
        **kwargs,
    ):
        """Create a FlexTensor from keyword-only Flex arguments.

        The Flex-specific arguments are accepted by keyword here and forwarded
        to the parent ``__new__`` as its leading POSITIONAL arguments, in the
        fixed order used below. Any extra positional arguments (``*args``) and
        remaining keyword arguments (``**kwargs``) are passed through after
        them unchanged.
        """
        # The positional order below is the contract with the parent
        # constructor; do not reorder it independently.
        return super().__new__(
            cls,
            rowwise_data,
            rowwise_scale_inv,
            columnwise_data,
            columnwise_scale_inv,
            amax_rowwise,
            amax_columnwise,
            dtype_row,
            dtype_column,
            quantizer,
            with_gemm_swizzled_scales,
            *args,
            **kwargs,
        )
+ """ + if dtype is None: + dtype = self.dtype + if torch.is_grad_enabled(): + return _FromFlexFunc.apply(self, dtype) + return _FromFlexFunc.forward(None, self, dtype) + + def _build_default_quantizer(self) -> Optional[Quantizer]: + """Build a default quantizer matching this tensor's per-direction formats.""" + return FlexQuantizer(dtype_row=self._dtype_row, dtype_column=self._dtype_column) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> FlexTensor: + """Update Flex tensor data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + return super().quantize_(tensor, noop_flag=noop_flag) + + def detach(self) -> FlexTensor: + # pylint: disable=missing-function-docstring + return FlexTensor.make_like(self) + + def clone(self) -> FlexTensor: + # pylint: disable=missing-function-docstring + + rowwise_data = None + if self._rowwise_data is not None: + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + } + ) + + def view(self, *shape: Tuple[int]) -> FlexTensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> FlexTensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> FlexTensor: + """Returns tensor with data in provided memory format + Returns `self` if data is already in correct memory format. 
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """Reject every dispatched op: per-op handling is not implemented yet.

        Raises
        ------
        NotImplementedError
            Always; the message names the op (``func``) that was requested.
        """
        # pylint: disable=missing-function-docstring
        # TODO: handle aten ops (view/copy_/split/...) per direction,
        # following MXFP8Tensor.__torch_dispatch__.
        raise NotImplementedError(
            f"FlexTensor.__torch_dispatch__ does not support {func} yet"
        )
+ ) + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + amax_rowwise: Optional[torch.Tensor], + amax_columnwise: Optional[torch.Tensor], + dtype_row: Optional[TE_DType], + dtype_column: Optional[TE_DType], + dtype: torch.dtype, + quantizer: Optional[Quantizer], + with_gemm_swizzled_scales: bool = False, + ) -> FlexTensor: + """Build FlexTensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional arguments. + + """ + return FlexTensor( + shape=shape, + dtype=dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + dtype_row=dtype_row, + dtype_column=dtype_column, + quantizer=quantizer, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + FlexTensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + self._dtype_row, + self._dtype_column, + self.dtype, + self._quantizer, + self._with_gemm_swizzled_scales, + ), + ) + + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("FlexTensor has no data!") + + @property + def shape(self): + """Return the logical shape of the tensor. 
+ + Each direction's buffer is laid out per its own format: MXFP8 stores the + logical shape directly, while NVFP4 is packed (and column-wise NVFP4 is + also transposed), so the byte shape is unpacked back to the logical shape. + """ + if self._rowwise_data is not None: + byte_shape = self._rowwise_data.shape + if self.is_nvfp4_dtype(self._dtype_row): + return torch.Size(byte_shape[:-1] + (byte_shape[-1] * 2,)) + return byte_shape + if self._columnwise_data is not None: + byte_shape = self._columnwise_data.shape + if self.is_nvfp4_dtype(self._dtype_column): + return torch.Size(byte_shape[1:-1] + (byte_shape[-1] * 2, byte_shape[0])) + return byte_shape + return torch.Tensor.size(self) + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("FlexTensor has no data!") + + def _get_data(self) -> FlexTensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes the quantized data if setting from a FlexTensor. Otherwise + casts to the flex format. 
def _flex_rowwise_byte_shape(dtype: Optional[TE_DType], shape) -> tuple:
    """Physical row-wise buffer shape for a logical ``shape``.

    Non-NVFP4 formats (e.g. MXFP8) store the logical shape; NVFP4 packs 2
    elements per byte, so the last dim is halved.

    Parameters
    ----------
    dtype : Optional[TE_DType]
        Quantized format of the row-wise direction.
    shape : sequence of int
        Logical shape. Any sequence is accepted (tuple, list, ``torch.Size``).

    Returns
    -------
    tuple
        Shape of the row-wise data buffer.

    Raises
    ------
    ValueError
        If the format is NVFP4 and the last dimension is odd.
    """
    # Normalise first: the NVFP4 branch concatenates with a tuple, which would
    # raise TypeError for a list shape (callers may pass a canonicalized view
    # shape that is a list).
    shape = tuple(shape)
    if FlexTensor.is_nvfp4_dtype(dtype):
        if shape[-1] % 2 != 0:
            raise ValueError(
                "Cannot represent row-wise NVFP4 quantized data for Flex tensor "
                f"with shape={tuple(shape)} as byte array."
            )
        return shape[:-1] + (shape[-1] // 2,)
    return shape
Each direction is viewed + # according to its own format; a None dtype means that direction is not + # quantized (its data buffer is None, so it is skipped). + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + byte_shape = _flex_rowwise_byte_shape(tensor._dtype_row, shape) + new_rowwise_data = tensor._rowwise_data.view(byte_shape) + if tensor._columnwise_data is not None: + byte_shape = _flex_columnwise_byte_shape(tensor._dtype_column, shape) + new_columnwise_data = tensor._columnwise_data.view(byte_shape) + + # Construct tensor + return FlexTensor( + shape, + dtype=tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + dtype_row=tensor._dtype_row, + dtype_column=tensor._dtype_column, + requires_grad=tensor.requires_grad, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, FlexTensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + byte_shape = _flex_rowwise_byte_shape(grad._dtype_row, ctx.shape) + new_rowwise_data = grad._rowwise_data.view(byte_shape) + if grad._columnwise_data is not None: + byte_shape = _flex_columnwise_byte_shape(grad._dtype_column, ctx.shape) + new_columnwise_data = grad._columnwise_data.view(byte_shape) + dgrad = FlexTensor( + ctx.shape, + dtype=grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, 
+ dtype_row=grad._dtype_row, + dtype_column=grad._dtype_column, + requires_grad=grad.requires_grad, + with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the FlexTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: FlexTensor, + shape: Optional[list[int]] = None, + ) -> FlexTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + shape = canonicalize_shape(shape, cur_shape) + if shape[-1] != cur_shape[-1]: + warnings.warn( + "FlexTensor does not support reshaping inner dimension. " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + "If you are using this for FSDP2 without compiled_autograd_enabled," + "then ignore this warning. Since this view is not going to be used anywhere. ", + stacklevel=2, + ) + return tensor.dequantize().reshape(*shape) + + # Construct new tensor if shape is provided. Each direction is reshaped + # according to its own format; a None dtype means that direction is not + # quantized (its data buffer is None, so it is skipped). 
+ new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + byte_shape = _flex_rowwise_byte_shape(tensor._dtype_row, shape) + new_rowwise_data = tensor._rowwise_data.reshape(byte_shape) + if tensor._columnwise_data is not None: + byte_shape = _flex_columnwise_byte_shape(tensor._dtype_column, shape) + new_columnwise_data = tensor._columnwise_data.reshape(byte_shape) + + # Construct tensor + return FlexTensor( + shape, + dtype=tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + dtype_row=tensor._dtype_row, + dtype_column=tensor._dtype_column, + requires_grad=tensor.requires_grad, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, FlexTensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + byte_shape = _flex_rowwise_byte_shape(grad._dtype_row, ctx.shape) + new_rowwise_data = grad._rowwise_data.reshape(byte_shape) + if grad._columnwise_data is not None: + byte_shape = _flex_columnwise_byte_shape(grad._dtype_column, ctx.shape) + new_columnwise_data = grad._columnwise_data.reshape(byte_shape) + dgrad = FlexTensor( + ctx.shape, + dtype=grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + dtype_row=grad._dtype_row, + dtype_column=grad._dtype_column, + requires_grad=grad.requires_grad, + 
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ba46508d74..a9983a19f9 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -16,7 +16,7 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc from ..constants import DType -from ..utils import devices_match, round_up_to_nearest_multiple +from ..utils import canonicalize_shape, devices_match, round_up_to_nearest_multiple aten = torch.ops.aten @@ -716,18 +716,7 @@ def forward( if shape is None: return tensor - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(ctx.shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, ctx.shape) if tensor._is_2D_scaled: # For the case of 2D scaled tensor, the last 2 dimensions should not change @@ -833,18 +822,7 @@ def forward( if shape is None: return tensor - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(tensor.shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, ctx.shape) if tensor._is_2D_scaled: # For the case of 2D scaled tensor, the last 2 dimensions should not change diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d759aaf5c4..e05aa4bf27 100644 
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -15,7 +15,7 @@ from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE, DType -from ..utils import devices_match, round_up_to_nearest_multiple +from ..utils import canonicalize_shape, devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc @@ -898,18 +898,7 @@ def forward( if shape is None: return tensor - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(ctx.shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, ctx.shape) if shape[-1] != ctx.shape[-1]: warnings.warn( "MXFP8Tensor does not support reshaping inner dimension. 
" @@ -991,18 +980,7 @@ def forward( if shape is None: return tensor - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(ctx.shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, ctx.shape) if shape[-1] != ctx.shape[-1]: raise RuntimeError( "MXFP8Tensor does not support reshaping inner dimension " diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5a2765b9f5..3c1b90190f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -17,6 +17,7 @@ from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type, DType from ..utils import ( canonicalize_process_group, + canonicalize_shape, devices_match, round_up_to_nearest_multiple, ) @@ -994,18 +995,7 @@ def forward( if shape is None: return tensor - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(cur_shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, cur_shape) if shape[-1] != cur_shape[-1]: warnings.warn( "NVFP4Tensor does not support reshaping inner dimension " @@ -1128,18 +1118,7 @@ def forward( if shape is None: return tensor - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(cur_shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = 
canonicalize_shape(shape, cur_shape) if shape[-1] != cur_shape[-1]: warnings.warn( "NVFP4Tensor does not support reshaping inner dimension " diff --git a/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py new file mode 100644 index 0000000000..52d65b81d1 --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py @@ -0,0 +1,393 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""TODO: write comments +""" + +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple, Union +from collections.abc import Iterable +import math +import warnings +import torch + +import transformer_engine_torch as tex # pylint: disable=unused-import +from transformer_engine_torch import ( + DType as TE_DType +) + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ...constants import TE_DType as torch_to_transformer_engine_dtype # pylint: disable=unused-import + +from ...utils import _empty_tensor, canonicalize_shape + +class _FromFlexFunc(torch.autograd.Function): + """Cast from MXFP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: FlexTensorStorage, + dtype: torch.dtype, + quantizer: Quantizer + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + if tensor._rowwise_data is not None and tensor._rowwise_data.numel() == 0: + return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) + if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: + return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) + + if tensor._rowwise_data is not None or tensor._columnwise_data is not None: + return tex.dequantize_with_quantizer(tensor, dtype, quantizer) + raise ValueError("Cannot dequantize Flex tensor with no data") + + + @staticmethod 
+ def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None, None + + +class FlexTensorStorage(QuantizedTensorStorage): + """Mixin class that holds data attributes of FlexTensor. + + FlexTensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + The two directions may carry different quantization formats, tracked + by ``_dtype_row`` and ``_dtype_column``. + + """ + + # Row-scaled quantized data, None if not quantized in this direction + _rowwise_data: Optional[torch.Tensor] + # Column-scaled quantized data, None if not quantized in this direction + _columnwise_data: Optional[torch.Tensor] + # Block scaling factors for row-scaled data, None if not quantized in this direction + _rowwise_scale_inv: Optional[torch.Tensor] + # Block scaling factors for column-scaled data, None if not quantized in this direction + _columnwise_scale_inv: Optional[torch.Tensor] + # Input absolute maximum for row-scaled data if quantized in NVFP4 row-wisely + # None if otherwise + _amax_rowwise: Optional[torch.Tensor] + # Input absolute maximum for column-scaled data if quantized in NVFP4 column-wisely + # Nont if otherwise + _amax_columnwise: Optional[torch.Tensor] + + # Builder class for casting to the flex format + _quantizer: Optional[Quantizer] + # Quantization format of the row-wise direction + _dtype_row: Optional[TE_DType] + # Quantization format of the column-wise direction + _dtype_column: Optional[TE_DType] + # Whether scaling factors are in the swizzled format expected by GEMM + _with_gemm_swizzled_scales: bool + + def __new__( + cls, + rowwise_data: Optional[torch.Tensor], + 
rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + amax_rowwise: Optional[torch.Tensor], + amax_columnwise: Optional[torch.Tensor], + dtype_row: Optional[TE_DType], + dtype_column: Optional[TE_DType], + quantizer: Optional[Quantizer], + with_gemm_swizzled_scales: bool, + *args, + fake_dtype: Optional[torch.dtype] = None, + **kwargs, + ): + if cls is FlexTensorStorage: + instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 + else: + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._amax_rowwise = amax_rowwise + instance._amax_columnwise = amax_columnwise + instance._quantizer = quantizer.copy() if quantizer is not None else None + instance._dtype_row = dtype_row + instance._dtype_column = dtype_column + instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + + return instance + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + for t in ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ): + if t is not None: + t.data = _empty_tensor() + + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another FlexTensorStorage.""" + if not isinstance(src, FlexTensorStorage): + raise TypeError("copy_from_storage expects FlexTensorStorage") + if self._dtype_row != src._dtype_row or self._dtype_column != src._dtype_column: + raise RuntimeError("Flex dtype mismatch in copy_from_storage") + if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + _copy_optional(self._amax_rowwise, src._amax_rowwise) + _copy_optional(self._amax_columnwise, src._amax_columnwise) + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "amax_rowwise": self._amax_rowwise, + "amax_columnwise": self._amax_columnwise, + "dtype_row": self._dtype_row, + "dtype_column": self._dtype_column, + "quantizer": self._quantizer, + "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "fake_dtype": self._dtype, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], FlexTensorStorage]: + """Prepare the 
tensor base for saving for backward""" + tensors = [ + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ] + self._rowwise_data = None + self._columnwise_data = None + self._rowwise_scale_inv = None + self._columnwise_scale_inv = None + self._amax_rowwise = None + self._amax_columnwise = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + self._rowwise_scale_inv = tensors[2] + self._columnwise_scale_inv = tensors[3] + self._amax_rowwise = tensors[4] + self._amax_columnwise = tensors[5] + return tensors[6:] + + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): + """Get this Tensor's data.""" + if rowwise_data and columnwise_data: + return self._rowwise_data, self._columnwise_data + if rowwise_data: + return self._rowwise_data + if columnwise_data: + return self._columnwise_data + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype + if self._rowwise_data is not None and self._rowwise_data.numel() == 0: + return torch.empty(self.size(), dtype=dtype, device=self.device) + return _FromFlexFunc.forward(None, self, dtype, self) + + def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: + # pylint: disable=missing-function-docstring + shape = None + if self._rowwise_data is not None: + if self.is_mxfp8_dtype(self._dtype_row): + shape = self._rowwise_data.shape + elif self.is_nvfp4_dtype(self._dtype_row): + byte_shape = list(self._rowwise_data.size()) + shape = byte_shape[:-1] + [byte_shape[-1] * 2] + elif 
self._columnwise_data is not None: + if self.is_mxfp8_dtype(self._dtype_column): + shape = self._columnwise_data + elif self.is_nvfp4_dtype(self._dtype_column): + warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.") + byte_shape = list(self._columnwise_data.size()) + shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]] + + if shape is None: + raise RuntimeError("Attempted to get shape of Flex tensor with no data") + if dim is None: + return torch.Size(shape) + return shape[dim] + + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("FlexTensorStorage has no data!") + + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + + # Return input tensor if view not needed + cur_shape = self.size() + if shape is None or shape == cur_shape: + return self + + shape = canonicalize_shape(shape, cur_shape) + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "FlexTensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})" + ) + + cur_rowwise_data = self._rowwise_data + cur_columnwise_data = self._columnwise_data + new_rowwise_data = None + new_columnwise_data = None + if self.is_mxfp8_dtype(self._dtype_row): + new_rowwise_data = cur_rowwise_data.view(*shape) + elif self.is_nvfp4_dtype(self._dtype_row): + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise NVFP4 quantized data for Flex tensor " + f"with shape={shape} as byte array." 
+ ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = self._rowwise_data.view(byte_shape) + if self.is_mxfp8_dtype(self._dtype_column): + new_columnwise_data = cur_columnwise_data.view(*shape) + elif self.is_nvfp4_dtype(self._dtype_column): + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise NVFP4 quantized data for Flex tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = self._columnwise_data.view(byte_shape) + + return FlexTensorStorage( + rowwise_data=new_rowwise_data, + rowwise_scale_inv=self._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=self._columnwise_scale_inv, + amax_rowwise=self._amax_rowwise, + amax_columnwise=self._amax_columnwise, + dtype_row=self._dtype_row, + dtype_column=self._dtype_column, + quantizer=self._quantizer, + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + fake_dtype=self._dtype, + ) + + def __repr__(self): + return ( + "FlexTensorStorage(" + f"dtype_row={self._dtype_row}, " + f"dtype_column={self._dtype_column}, " + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + f"columnwise_scale_inv={self._columnwise_scale_inv}, " + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + TODO: figure out what to say here + """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but FlexTensor is missing row-scaled data" + ) + if self._rowwise_scale_inv is 
None: + raise RuntimeError( + "Requested row-wise usage, but FlexTensor is missing row-scaled scale-inverses" + ) + if self._amax_rowwise is None and self.is_nvfp4_dtype(self._dtype_row): + raise RuntimeError( + "Requested row-wise NVFP4 usage, but FlexTensor is missing per tensor" + " row-scaled scale-inverse" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + self._amax_rowwise = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but FlexTensor is missing column-scaled data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but FlexTensor is missing column-scaled scale-inverses" + ) + if self._amax_columnwise is None and self.is_nvfp4_dtype(self._dtype_column): + raise RuntimeError( + "Requested column-wise NVFP4 usage, " + "but FlexTensor is missing per tensor column-scaled scale-inverse" + ) + else: + self._columnwise_data = None + self._columnwise_scale_inv = None + self._amax_columnwise = None + + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } + + @staticmethod + def is_mxfp8_dtype(dtype: TE_DType) -> bool: + return dtype == tex.DType.kFloat8E4M3 or dtype == tex.DType.kFloat8E5M2 + + @staticmethod + def is_nvfp4_dtype(dtype: TE_DType) -> bool: + return dtype == tex.DType.kFloat4E2M1 diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index ea592cd989..741535b156 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -16,7 +16,7 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType -from ...utils import _empty_tensor +from ...utils 
import _empty_tensor, canonicalize_shape class _FromMXFP8Func(torch.autograd.Function): @@ -217,18 +217,7 @@ def view(self, shape: torch.Size): if shape is None or shape == cur_shape: return self - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(cur_shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, cur_shape) if shape[-1] != cur_shape[-1]: raise RuntimeError( "MXFP8Tensor does not support reshaping inner dimension " diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 53bb5e7c11..2407dff1e7 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -283,18 +283,7 @@ def view(self, shape: torch.Size): if shape is None or shape == cur_shape: return self - # Canonicalize shape - if not isinstance(shape, Iterable): - shape = [shape] - elif len(shape) == 1 and isinstance(shape[0], Iterable): - shape = shape[0] - if -1 in shape: - shape = list(shape) - d_inferred = -math.prod(cur_shape) // math.prod(shape) - for i, d in enumerate(shape): - if d == -1: - shape[i] = d_inferred - break + shape = canonicalize_shape(shape, cur_shape) if shape[-1] != cur_shape[-1]: raise RuntimeError( "NVFP4Tensor does not support reshaping inner dimension " diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index cfb21e7bff..baf90e7a4a 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -8,7 +8,7 @@ import math import os import warnings -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, List, 
def canonicalize_shape(shape, cur_shape):
    """Normalize a requested view/reshape target against the current shape.

    Args:
        shape: Requested shape. May be a single non-iterable dimension, a
            sequence of dimensions, or a one-element sequence wrapping the
            actual sequence (as produced by ``f(*args)`` call styles).
        cur_shape: Current shape, used to infer a ``-1`` entry.

    Returns:
        ``shape`` unchanged (apart from the unwrapping above) when it has no
        ``-1`` entry; otherwise a new list with the ``-1`` entry replaced by
        the inferred size.

    Raises:
        RuntimeError: If more than one entry is ``-1``, or the ``-1`` entry
            cannot be inferred (the known dimensions multiply to zero or do
            not divide the current element count).
    """
    if not isinstance(shape, Iterable):
        shape = [shape]
    elif len(shape) == 1 and isinstance(shape[0], Iterable):
        shape = shape[0]

    if -1 not in shape:
        return shape

    shape = list(shape)
    if shape.count(-1) > 1:
        raise RuntimeError(f"Only one dimension can be inferred (got shape={shape})")

    # Multiply only the explicitly given dims; the previous formulation
    # divided by the product including -1 and hit ZeroDivisionError when any
    # given dim was 0.
    numel = math.prod(cur_shape)
    known = math.prod(d for d in shape if d != -1)
    if known == 0 or numel % known != 0:
        raise RuntimeError(
            f"Cannot infer -1 in shape={shape} for tensor with dims={tuple(cur_shape)}"
        )
    shape[shape.index(-1)] = numel // known
    return shape
+18,23 @@ # 2 aligned (no scale padding) + 2 padded (partial tiles); SHAPES = [(256, 256), (128, 512), (96, 224), (160, 96)] + def get_dtype_combinations(): dtype_row = ("e4m3", "e5m2", "none") dtype_column = ("e4m3", "e5m2", "none") return [(r, c) for r in dtype_row for c in dtype_column] + DTYPE_PAIRS = get_dtype_combinations() + def reference_quantize(x, fp8_type, rowwise, columnwise, swizzle): q = MXFP8Quantizer(fp8_dtype=str_to_te_dtype(fp8_type), rowwise=rowwise, columnwise=columnwise) q.optimize_for_gemm = swizzle # makes the native kernel emit swizzled scales ref = tex.quantize(x.clone(), q) return ref + @pytest.mark.parametrize("swizzle", [False, True]) @pytest.mark.parametrize("dtype_pair", DTYPE_PAIRS) @pytest.mark.parametrize("shape", SHAPES) @@ -57,15 +61,22 @@ def test_flex_mxfp8_bitexact(shape, dtype_pair, swizzle): # Reference for this direction uses THIS direction's dtype. ref = reference_quantize(x, dtype_row, rowwise=True, columnwise=False, swizzle=swizzle) assert ref._rowwise_data.shape == flex._rowwise_data.shape, "rowwise data shape mismatch" - assert ref._rowwise_scale_inv.shape == flex._rowwise_scale_inv.shape, "rowwise scale shape mismatch" - torch.testing.assert_close(flex._rowwise_data, ref._rowwise_data, rtol=0, atol=0) # bit-identical + assert ( + ref._rowwise_scale_inv.shape == flex._rowwise_scale_inv.shape + ), "rowwise scale shape mismatch" + torch.testing.assert_close( + flex._rowwise_data, ref._rowwise_data, rtol=0, atol=0 + ) # bit-identical if swizzle: - torch.testing.assert_close(flex._rowwise_scale_inv, ref._rowwise_scale_inv, rtol=0, atol=0) + torch.testing.assert_close( + flex._rowwise_scale_inv, ref._rowwise_scale_inv, rtol=0, atol=0 + ) else: torch.testing.assert_close( flex._rowwise_scale_inv[:scale_M, :scale_N], - ref._rowwise_scale_inv[:scale_M, :scale_N], - rtol=0, atol=0 + ref._rowwise_scale_inv[:scale_M, :scale_N], + rtol=0, + atol=0, ) else: assert flex._rowwise_data is None, "row=none must not produce rowwise data" 
@@ -73,20 +84,30 @@ def test_flex_mxfp8_bitexact(shape, dtype_pair, swizzle): if dtype_column != "none": scale_M, scale_N = M // MXFP8_BLOCK, N ref = reference_quantize(x, dtype_column, rowwise=False, columnwise=True, swizzle=swizzle) - assert ref._columnwise_data.shape == flex._columnwise_data.shape, "columnwise data shape mismatch" - assert ref._columnwise_scale_inv.shape == flex._columnwise_scale_inv.shape, "columnwise scale shape mismatch" - torch.testing.assert_close(flex._columnwise_data, ref._columnwise_data, rtol=0, atol=0) # bit-identical + assert ( + ref._columnwise_data.shape == flex._columnwise_data.shape + ), "columnwise data shape mismatch" + assert ( + ref._columnwise_scale_inv.shape == flex._columnwise_scale_inv.shape + ), "columnwise scale shape mismatch" + torch.testing.assert_close( + flex._columnwise_data, ref._columnwise_data, rtol=0, atol=0 + ) # bit-identical if swizzle: - torch.testing.assert_close(flex._columnwise_scale_inv, ref._columnwise_scale_inv, rtol=0, atol=0) + torch.testing.assert_close( + flex._columnwise_scale_inv, ref._columnwise_scale_inv, rtol=0, atol=0 + ) else: torch.testing.assert_close( flex._columnwise_scale_inv[:scale_M, :scale_N], - ref._columnwise_scale_inv[:scale_M, :scale_N], - rtol=0, atol=0 + ref._columnwise_scale_inv[:scale_M, :scale_N], + rtol=0, + atol=0, ) else: assert flex._columnwise_data is None, "col=none must not produce colwise data" + def test_flex_mxfp8_wrong_shape(): """A quantizer is compiled for a specific (M, N); using it on a different N must error rather than silently mis-quantize. diff --git a/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py b/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py index 350c8c61bc..83c19f5639 100644 --- a/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py +++ b/transformer_engine/common/cutedsl/cast/mxfp8_quantization.py @@ -42,36 +42,43 @@ ) # MXFP8 settings -MXFP8_BLOCK_SIZE = 32 # Number of elements per MXFP8 scale block. 
They will share the same E8M0 scale factor +MXFP8_BLOCK_SIZE = ( + 32 # Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor +) SCALE_DIM = MXFP8_BLOCK_SIZE # Double-buffering for async copy + compute overlap BUFFER_NUM = 2 # Vectorised access constants for bank-conflict avoidance (rowwise pass) -PACK_SIZE = 4 # Elements per vector load -WAVES = SCALE_DIM // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +PACK_SIZE = 4 # Elements per vector load +WAVES = ( + SCALE_DIM // PACK_SIZE +) # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total THREADS_PER_WARP = 32 TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) THREADS_PER_BANK = TOTAL_BANKS_WIDTH // SCALE_DIM # 4 threads per bank # Tiling sizes -NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) -NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension -TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total -TILE_X = 64 # Each tile has 64 columns +NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) +NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension +TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total +TILE_X = 64 # Each tile has 64 columns # CTA size THREADS_PER_CHUNK = 64 NUM_WARPS = THREADS_PER_CHUNK // 32 + class MXFP8QuantizeConfig: - def __init__(self, - dtype: torch.dtype, - dtype_row: Union[Literal["e4m3", "e5m2", "none"]], - dtype_column: Union[Literal["e4m3", "e5m2", "none"]], - with_gemm_swizzled_scales=False): + def __init__( + self, + dtype: torch.dtype, + dtype_row: Union[Literal["e4m3", "e5m2", "none"]], + dtype_column: Union[Literal["e4m3", "e5m2", "none"]], + 
with_gemm_swizzled_scales=False, + ): self.DTYPE = dtype self.DTYPE_ROW = dtype_row self.ROWWISE = dtype_row != "none" @@ -84,6 +91,7 @@ def __init__(self, self.WITH_AMAX = False self.ACTIVATION = None + class MXFP8QuantizeKernel: """MXFP8 quantization with shared-memory tiling (rowwise, colwise, or both). @@ -109,9 +117,11 @@ def __init__(self, cfg): def __call__( self, mX: cute.Tensor, # Input tensor to quantize - mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], # Rowwise data + scale + mO_row: Optional[cute.Tensor], + mS_row: Optional[cute.Tensor], # Rowwise data + scale mA_row: Optional[cute.Tensor], # Rowwise amax (None for MXFP8) - mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Colwise data + scale + mO_col: Optional[cute.Tensor], + mS_col: Optional[cute.Tensor], # Colwise data + scale mA_col: Optional[cute.Tensor], # Colwise amax (None for MXFP8) rng_state: Optional[cute.Tensor], # SR seed/offset (None when SR disabled) stream: cuda.CUstream, # launch stream (C++ passes the handle as an int64 scalar) @@ -121,7 +131,7 @@ def __call__( cfg = self.cfg num_scale_cols = N // SCALE_DIM num_scale_rows = M // SCALE_DIM - + # Rewrap mS_row / mS_col with the GEMM-swizzled layout when requested. # Wrapper passes in a tensor with the compact (M, N/32):(N/32, 1) layout # (built from a compact fake-ptr at compile time), and we re-view the @@ -131,12 +141,12 @@ def __call__( # and swizzle_demo.svg for a visual of the byte permutation. if cutlass.const_expr(cfg.WITH_GEMM_SWIZZLED_SCALES): num_tiles_M = (M + 127) // 128 - num_tiles_SC = (num_scale_cols + 3) // 4 # = ceil(N / 128) - num_tiles_SR = (num_scale_rows + 3) // 4 # = ceil(M / 128) + num_tiles_SC = (num_scale_cols + 3) // 4 # = ceil(N / 128) + num_tiles_SR = (num_scale_rows + 3) // 4 # = ceil(M / 128) num_tiles_N = (N + 127) // 128 # row i = i_lo + 32 * (i_hi + 4 * tile_Y); col j = j_lo + 4 * tile_X. # Within one 128×4 tile: byte = i_lo*16 + i_hi*4 + j_lo. 
- + # Tile-major outer dims add (tile_Y * num_tiles_SC + tile_X) * 512. # For example, if M=256, N=512, then num_scale_cols = 16, num_scale_rows = 8, and num_tiles_M=2, num_tiles_SC=4, num_tiles_SR=2, num_tiles_N=4 # The swizzled layout is ((32, 4, 2), (4, 4)):((16, 4, 2048), (1, 512)) @@ -158,7 +168,7 @@ def __call__( stride=((1, 512), (16, 4, num_tiles_SR * 512)), ), ) - + # Divide by the STAGE tile (TILE_Y, TILE_X // SCALE_DIM), not the CTA # tile. Each CTA owns NUM_TILES consecutive row-tiles; the kernel walks # them by indexing GRID's row dim with `bidy * NUM_TILES + stage` (cute @@ -170,20 +180,24 @@ def __call__( # → SCALE_TILE = (32, 2):(16, 1) # The bigger (TILE_Y * NUM_TILES, ...) divide we used before tangles the # swizzle's (32, 4) row hierarchy under flatten + sub-divide chain. - + # Declare TMA descriptors on the host side. # make_tiled_tma_atom returns the UNTILED gmem tensor with basis strides. # Tile it inside the kernel with zipped_divide so each coord selects # one (TILE_Y, TILE_X) tile. smem_tile_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) cta_tiler = (TILE_Y, TILE_X) - + # Input: TMA G2S (bf16/fp16 → smem). op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() tma_atom, tma_src = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_load, mX, smem_tile_layout, cta_tiler, num_multicast=1, + op_load, + mX, + smem_tile_layout, + cta_tiler, + num_multicast=1, ) - + # Output: TMA S2G (uint8 smem → gmem) for both directions. Creating # both atoms unconditionally — if a direction is disabled the kernel # simply won't dispatch its copy, and the atom cost is negligible. 
@@ -195,27 +209,43 @@ def __call__( tma_dst_out_col = None if cutlass.const_expr(cfg.ROWWISE): tma_atom_out_row, tma_dst_out_row = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_store, mO_row, out_smem_layout, cta_tiler, num_multicast=1, + op_store, + mO_row, + out_smem_layout, + cta_tiler, + num_multicast=1, ) if cutlass.const_expr(cfg.COLUMNWISE): tma_atom_out_col, tma_dst_out_col = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_store, mO_col, out_smem_layout, cta_tiler, num_multicast=1, + op_store, + mO_col, + out_smem_layout, + cta_tiler, + num_multicast=1, ) - - # CUDA launches in (0,0), (1,0), (2,0)... order, so we should make N the leading dimension for better access pattern + + # CUDA launches in (0,0), (1,0), (2,0)... order, so we should make N the leading dimension for better access pattern # So consecutive blocks will move along the N dimension first, which is the innermost dimension in memory and we can use cache better grid = [ cute.ceil_div(Int32(N), TILE_X), cute.ceil_div(M, TILE_Y * NUM_TILES), ] - block = [THREADS_PER_CHUNK,] - + block = [ + THREADS_PER_CHUNK, + ] + self.kernel( - mX, mS_row, mS_col, None, # mAmax = None (no amax for the MXFP8 path) + mX, + mS_row, + mS_col, + None, # mAmax = None (no amax for the MXFP8 path) mX.element_type, - tma_atom, tma_src, - tma_atom_out_row, tma_dst_out_row, - tma_atom_out_col, tma_dst_out_col, + tma_atom, + tma_src, + tma_atom_out_row, + tma_dst_out_row, + tma_atom_out_col, + tma_dst_out_col, ).launch( grid=grid, block=block, @@ -230,9 +260,12 @@ def kernel( mS_col, mAmax, dtype: cutlass.Constexpr[Type[cutlass.Numeric]], - tma_atom, tma_src, # how to use TMA to copy the input - tma_atom_out_row, tma_dst_out_row, # how to use TMA to copy the rowwise output - tma_atom_out_col, tma_dst_out_col, # how to use TMA to copy the colwise output + tma_atom, + tma_src, # how to use TMA to copy the input + tma_atom_out_row, + tma_dst_out_row, # how to use TMA to copy the rowwise output + tma_atom_out_col, + 
tma_dst_out_col, # how to use TMA to copy the colwise output ): cfg = self.cfg @@ -254,6 +287,7 @@ def kernel( # negligible (8 bytes for NUM_WARPS=2) and we always allocate so the # struct doesn't fork on a 4th const-expr (cfg.WITH_AMAX) dimension. if cutlass.const_expr(cfg.ROWWISE and cfg.COLUMNWISE): + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] @@ -267,7 +301,9 @@ class SharedStorage: cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE and not cfg.COLUMNWISE): + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] @@ -278,7 +314,9 @@ class SharedStorage: cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE): + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] @@ -289,7 +327,9 @@ class SharedStorage: cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + else: + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] @@ -300,6 +340,7 @@ class SharedStorage: cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -355,15 +396,11 @@ class SharedStorage: producer_group=producer_group, consumer_group=consumer_group, tx_count=tx_count, - cta_layout_vmnk=None, # single-CTA, no cluster/multicast + cta_layout_vmnk=None, # single-CTA, no cluster/multicast ) - prod_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, NUM_STAGES - ) - cons_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, NUM_STAGES - ) + prod_state = 
pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, NUM_STAGES) + cons_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, NUM_STAGES) M = mX.shape[0] N = mX.shape[1] @@ -379,8 +416,8 @@ class SharedStorage: # Partition sX/gX for the TMA atom (single-CTA, no cluster/multicast). tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( tma_atom, - 0, # Use the only CTA to do the TMA copy - cute.make_layout(1), # This cluster only has 1 CTAs + 0, # Use the only CTA to do the TMA copy + cute.make_layout(1), # This cluster only has 1 CTAs sX, gX_tiled, ) @@ -437,7 +474,7 @@ class SharedStorage: # ---- Consumer: all threads quantize each completed tile. ---- for stage in cutlass.range(num_tiles, unroll=1): mainloop_pipeline.consumer_wait(cons_state) - sX_tile = sX[(None, stage)] # (TILE_Y, TILE_X) bf16 + sX_tile = sX[(None, stage)] # (TILE_Y, TILE_X) bf16 """ grid = [ @@ -448,7 +485,7 @@ class SharedStorage: """ # This is just block's x axis idx tile_idx_x = bidx - # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. + # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. 
# So the tile index along Y dimension is `bidy * NUM_TILES + stage` tile_idx_y = bidy * NUM_TILES + stage @@ -459,9 +496,13 @@ class SharedStorage: mS_col_stage = cute.flatten(mS_col[(None, (tile_idx_y, tile_idx_x))]) amax_c = self._process_colwise( - sX_tile, sO_col_tile, + sX_tile, + sO_col_tile, mS_col_stage, - tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + tile_idx_y * TILE_Y, + bidx * TILE_X, + M, + N, ) if cutlass.const_expr(cfg.ROWWISE): @@ -479,9 +520,13 @@ class SharedStorage: # print(f"mS_row_stage: {mS_row_stage}\n") # print(f"mS_row_stage: {mS_row_stage}\n") amax_r = self._process_rowwise( - sX_tile, sO_row_tile, + sX_tile, + sO_row_tile, mS_row_stage, - tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + tile_idx_y * TILE_Y, + bidx * TILE_X, + M, + N, ) # Make all smem stores (sO_row and/or sO_col) visible to the TMA @@ -517,7 +562,7 @@ class SharedStorage: # before the kernel returns. cute.arch.cp_async_bulk_wait_group(0, read=False) - # ---- amax block reduction + cross-CTA atomic ---------------------- + # ---- amax block reduction + cross-CTA atomic ---------------------- # 1) intra-warp: redux.sync.fmax.f32 (sm_80+, single instruction). # 2) cross-warp: NUM_WARPS shmem floats + sync_threads. # 3) cross-CTA: int-atomic-max on the f32 bit pattern. 
Since amax is @@ -541,19 +586,20 @@ class SharedStorage: cute.make_layout(1), ) cute.arch.atomic_max( - amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), + amax_i32.iterator, + _bitcast_f32_to_i32(cta_amax), ) - @cute.jit def _process_rowwise( self, - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) - mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) - tile_row_start, # Int32 — global row of this stage's row 0 - tile_col_start, # Int32 — global col of this CTA's col 0 - M, N, # Int32 — full input extents, for OOB masking + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, + N, # Int32 — full input extents, for OOB masking ): """Rowwise MXFP8 pass: thread `(tid_Y, tid_X) = (tidx % 32, tidx // 32)` owns one 32-element scale block (row `tid_Y`, columns `tid_X*32 .. +32`). 
@@ -587,18 +633,19 @@ def _process_rowwise( WAVES=WAVES, THREADS_PER_WARP=THREADS_PER_WARP, THREADS_PER_BANK=THREADS_PER_BANK, - PACK_SIZE=PACK_SIZE + PACK_SIZE=PACK_SIZE, ) @cute.jit def _process_colwise( self, - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) - mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) - tile_row_start, # Int32 — global row of this stage's row 0 - tile_col_start, # Int32 — global col of this CTA's col 0 - M, N, # Int32 — full input extents, for OOB masking + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, + N, # Int32 — full input extents, for OOB masking ): """Colwise MXFP8 pass: thread `tidx` owns column `tidx` of the (32, 64) smem tile — 32 elements down. 
Writes quantized bytes into `sO_col_tile` @@ -616,7 +663,8 @@ def _process_colwise( max_norm_rcp, tile_row_start, tile_col_start, - M, N, + M, + N, ACTIVATION=None, DTYPE=cfg.DTYPE, FP8_DTYPE=cfg.DTYPE_COLUMN, @@ -626,18 +674,28 @@ def _process_colwise( SCALE_DIM=SCALE_DIM, ) + def _cfg_to_fn_name(cfg, M, N) -> str: """Deterministic registry key from (cfg, shape).""" - key = (cfg.DTYPE.__name__, cfg.DTYPE_ROW, cfg.DTYPE_COLUMN, - int(cfg.ROWWISE), int(cfg.COLUMNWISE), - int(cfg.WITH_GEMM_SWIZZLED_SCALES), int(cfg.WITH_AMAX), - cfg.ACTIVATION or "none", - M, N) + key = ( + cfg.DTYPE.__name__, + cfg.DTYPE_ROW, + cfg.DTYPE_COLUMN, + int(cfg.ROWWISE), + int(cfg.COLUMNWISE), + int(cfg.WITH_GEMM_SWIZZLED_SCALES), + int(cfg.WITH_AMAX), + cfg.ACTIVATION or "none", + M, + N, + ) h = hashlib.sha1(repr(key).encode()).hexdigest()[:16] return f"mxfp8_{h}" + _compile_cache_tvm_ffi: dict = {} + def _get_compiled_kernel(cfg, M, N): """Compile the kernel for THIS (cfg, M, N) with LITERAL shapes — every dimension is a constexpr int, so the AOT wrapper's per-arg type collapses @@ -658,24 +716,22 @@ def _get_compiled_kernel(cfg, M, N): # rowwise: (roundup(M, 128), roundup(N // 32, 4)) # columnwise: (roundup(M // 32, 4), roundup(N, 128)) SCALE_R = (((M + 127) // 128) * 128, ((N + 127) // 128) * 4) - SCALE_C = (((M + 127) // 128) * 4, ((N + 127) // 128) * 128) + SCALE_C = (((M + 127) // 128) * 4, ((N + 127) // 128) * 128) WS_M = (M + TILE_Y * NUM_TILES - 1) // (TILE_Y * NUM_TILES) # stride_order=(1, 0): row-major, dim 1 stride 1. 1D: (0,). 
- kw_rm16_2d = dict(stride_order=(1, 0), - memspace=cute.AddressSpace.gmem, assumed_align=16) - kw_rm4_2d = dict(stride_order=(1, 0), - memspace=cute.AddressSpace.gmem, assumed_align=4) - kw_rm4_1d = dict(stride_order=(0,), - memspace=cute.AddressSpace.gmem, assumed_align=4) + kw_rm16_2d = dict(stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) + kw_rm4_2d = dict(stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) + kw_rm4_1d = dict(stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4) + def fake(dtype, shape, kw): return cute.runtime.make_fake_compact_tensor(dtype, shape, **kw) - in_fake = fake(cfg.DTYPE, (M, N), kw_rm16_2d) - out_row_fake = fake(cute.Uint8, (M, N), kw_rm16_2d) if cfg.ROWWISE else None - scale_row_fake = fake(cute.Uint8, SCALE_R, kw_rm16_2d) if cfg.ROWWISE else None - out_col_fake = fake(cute.Uint8, (M, N), kw_rm16_2d) if cfg.COLUMNWISE else None - scale_col_fake = fake(cute.Uint8, SCALE_C, kw_rm16_2d) if cfg.COLUMNWISE else None + in_fake = fake(cfg.DTYPE, (M, N), kw_rm16_2d) + out_row_fake = fake(cute.Uint8, (M, N), kw_rm16_2d) if cfg.ROWWISE else None + scale_row_fake = fake(cute.Uint8, SCALE_R, kw_rm16_2d) if cfg.ROWWISE else None + out_col_fake = fake(cute.Uint8, (M, N), kw_rm16_2d) if cfg.COLUMNWISE else None + scale_col_fake = fake(cute.Uint8, SCALE_C, kw_rm16_2d) if cfg.COLUMNWISE else None # No amax / no SR for the MXFP8 path: these slots are None. The kernel's # __call__ takes them as Optional and dead-strips the amax/SR branches. 
amax_row_fake = None @@ -687,11 +743,15 @@ def fake(dtype, shape, kw): compiled = cute.compile( kernel_obj, - in_fake, # mX - out_row_fake, scale_row_fake, amax_row_fake, # mO_row, mS_row, mA_row - out_col_fake, scale_col_fake, amax_col_fake, # mO_col, mS_col, mA_col - rng_state_fake, # rng_state - stream_fake, # stream + in_fake, # mX + out_row_fake, + scale_row_fake, + amax_row_fake, # mO_row, mS_row, mA_row + out_col_fake, + scale_col_fake, + amax_col_fake, # mO_col, mS_col, mA_col + rng_state_fake, # rng_state + stream_fake, # stream options="--enable-tvm-ffi", ) cache[fn_name] = compiled @@ -738,10 +798,10 @@ def get_mxfp8_quantizer( # Dequant isn't implemented for this storage-only milestone, but # FlexQuantizer::quantize (C++) rejects an empty dequantize_func, so # carry a placeholder name. It is never resolved/called by quantize(). - dequantize_func=f"{fn_name}_dequant", # TODO: this is fake, implement dequant and remove the placeholder + dequantize_func=f"{fn_name}_dequant", # TODO: this is fake, implement dequant and remove the placeholder stochastic_rounding=False, ) # Storage-only milestone: no full PyTorch-tensor compatibility / GEMM yet. 
quantizer.internal = True quantizer.optimize_for_gemm = with_gemm_swizzled_scales - return quantizer \ No newline at end of file + return quantizer diff --git a/transformer_engine/common/cutedsl/cast/quantization_utils.py b/transformer_engine/common/cutedsl/cast/quantization_utils.py index 785185efaa..b54367f3b3 100644 --- a/transformer_engine/common/cutedsl/cast/quantization_utils.py +++ b/transformer_engine/common/cutedsl/cast/quantization_utils.py @@ -51,9 +51,11 @@ def float_to_e8m0(val: Float32, *, loc=None, ip=None) -> Int32: val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) rounded = val_i32 + Int32(0x7FFFFF) exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) - return Int32(mlir_arith.minsi( - exponent.ir_value(loc=loc, ip=ip), - Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return Int32( + mlir_arith.minsi( + exponent.ir_value(loc=loc, ip=ip), Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) @dsl_user_op @@ -61,14 +63,20 @@ def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: """2^(127 - biased_exp) with special-case handling.""" new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS) result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip) - for (cmp_val, repl_bits) in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: - cond = mlir_arith.cmpi(mlir_arith.CmpIPredicate.eq, - biased_exp.ir_value(loc=loc, ip=ip), - Int32(cmp_val).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + for cmp_val, repl_bits in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: + cond = mlir_arith.cmpi( + mlir_arith.CmpIPredicate.eq, + biased_exp.ir_value(loc=loc, ip=ip), + Int32(cmp_val).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip) - result = Float32(mlir_arith.select( - cond, alt.ir_value(loc=loc, ip=ip), - result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result = Float32( + mlir_arith.select( + cond, alt.ir_value(loc=loc, ip=ip), 
result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) return result @@ -76,14 +84,20 @@ def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: """float32 -> fp8e4m3fn via PTX cvt.rn.satfinite.e4m3x2.f32.""" zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return result_i32 & Int32(0xFF) @@ -91,14 +105,20 @@ def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: """float32 -> fp8e5m2 via PTX cvt.rn.satfinite.e5m2x2.f32.""" zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), 
result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return result_i32 & Int32(0xFF) @@ -107,25 +127,33 @@ def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32 """`fma.rn.f32 d, a, b, c;` — single-instruction fused multiply-add matching nvcc's FFMA. Used for explicit `partial += a * b` patterns where we need the same rounding as TE's compiler-fused FFMA.""" - return Float32(llvm.inline_asm( - T.f32(), - [a.ir_value(loc=loc, ip=ip), - b.ir_value(loc=loc, ip=ip), - c.ir_value(loc=loc, ip=ip)], - "fma.rn.f32 $0, $1, $2, $3;", - "=f,f,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32( + llvm.inline_asm( + T.f32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip), c.ir_value(loc=loc, ip=ip)], + "fma.rn.f32 $0, $1, $2, $3;", + "=f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op def tanh_approx(val: Float32, *, loc=None, ip=None) -> Float32: """`tanh.approx.f32` — fast tanh approximation. Matches CUDA `__tanhf`.""" - return Float32(llvm.inline_asm( - T.f32(), - [val.ir_value(loc=loc, ip=ip)], - "tanh.approx.f32 $0, $1;", - "=f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32( + llvm.inline_asm( + T.f32(), + [val.ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op @@ -135,12 +163,17 @@ def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};` which lowers to a single register move — no actual memory traffic. 
""" - return Int64(llvm.inline_asm( - T.i64(), - [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], - "mov.b64 $0, {$1, $2};", - "=l,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int64( + llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op @@ -148,12 +181,17 @@ def pack_i32x2(lo: Int32, hi: Int32, *, loc=None, ip=None) -> Int64: """i32 sibling of `pack_f32x2` — concat two i32 into a single b64 register. Used by NVFP4 to glue two `(bf16,bf16)`/`(f16,f16)` Int32 packs into the `Int64` operand the `mul_cvt.*x4` PTX expects.""" - return Int64(llvm.inline_asm( - T.i64(), - [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], - "mov.b64 $0, {$1, $2};", - "=l,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int64( + llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op @@ -163,8 +201,7 @@ def _trunc_i32_to_i16(val: Int32, *, loc=None, ip=None) -> Int16: Lives here because the existing arith-dialect narrowing pattern requires loc/ip kwargs (see other `mlir_arith.trunci` callers); wrapping it as a `@dsl_user_op` lets `@cute.jit` bodies use it without plumbing those in.""" - return Int16(mlir_arith.trunci( - T.i16(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return Int16(mlir_arith.trunci(T.i16(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @dsl_user_op @@ -185,12 +222,18 @@ def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32: "cvt.f32.f16 $0, lo_f16;\n\t" "}" ) - return Float32(llvm.inline_asm( - T.f32(), - [byte_i32.ir_value(loc=loc, ip=ip)], - asm, - "=f,r", 
has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32( + llvm.inline_asm( + T.f32(), + [byte_i32.ir_value(loc=loc, ip=ip)], + asm, + "=f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + # --------------------------------------------------------------------------- # 16-bit packed input PTX kit (bf16 / f16) @@ -219,100 +262,140 @@ def _build_packed16_kit(in_fmt: str): @dsl_user_op def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: - return Int32(llvm.inline_asm( - T.i32(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", - "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - + return Int32( + llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + @dsl_user_op def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: - return Int32(llvm.inline_asm( - T.i32(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.{in_fmt}x2 $0, $1, $2;", - "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32( + llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16: - return Int16(llvm.inline_asm( - T.i16(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.xorsign.abs.{in_fmt} $0, $1, $2;", - "=h,h,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int16( + llvm.inline_asm( + T.i16(), + 
[a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt} $0, $1, $2;", + "=h,h,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) if in_fmt == "bf16": # bf16 == top 16 bits of f32 — widening is a free bit-shift. @dsl_user_op def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: - i32 = Int32(mlir_arith.extui( - T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + i32 = Int32(mlir_arith.extui(T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) @dsl_user_op def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - return _bitcast_i32_to_f32( - (bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) + return _bitcast_i32_to_f32((bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) @dsl_user_op def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal # issues. Sign bits from the arith-right shift get zeroed by the # left shift. - return _bitcast_i32_to_f32( - (bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) + return _bitcast_i32_to_f32((bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) @dsl_user_op def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: """Round f32 to bf16 precision (round-to-nearest-even), keep f32. 
Matches C++'s `static_cast(static_cast(elt))`.""" - bf16_bits = Int16(llvm.inline_asm( - T.i16(), [val.ir_value(loc=loc, ip=ip)], - "cvt.rn.bf16.f32 $0, $1;", - "=h,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - i32 = Int32(mlir_arith.extui( - T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + bf16_bits = Int16( + llvm.inline_asm( + T.i16(), + [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.bf16.f32 $0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + i32 = Int32( + mlir_arith.extui(T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + else: # f16 has its own bit layout; widening requires `cvt.f32.f16`. @dsl_user_op def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: - return Float32(llvm.inline_asm( - T.f32(), [bits.ir_value(loc=loc, ip=ip)], - "cvt.f32.f16 $0, $1;", - "=f,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32( + llvm.inline_asm( + T.f32(), + [bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - lo_i16 = Int16(mlir_arith.trunci( - T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + lo_i16 = Int16( + mlir_arith.trunci(T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return bits_to_f32(lo_i16, loc=loc, ip=ip) @dsl_user_op def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: hi_shifted = bits >> Int32(16) - hi_i16 = Int16(mlir_arith.trunci( - T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + hi_i16 = Int16( + mlir_arith.trunci(T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return bits_to_f32(hi_i16, loc=loc, ip=ip) @dsl_user_op def truncate_f32(val: 
Float32, *, loc=None, ip=None) -> Float32: """Round f32 to f16 precision, keep f32.""" - f16_bits = Int16(llvm.inline_asm( - T.i16(), [val.ir_value(loc=loc, ip=ip)], - "cvt.rn.f16.f32 $0, $1;", - "=h,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Float32(llvm.inline_asm( - T.f32(), [f16_bits.ir_value(loc=loc, ip=ip)], - "cvt.f32.f16 $0, $1;", - "=f,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + f16_bits = Int16( + llvm.inline_asm( + T.i16(), + [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.f16.f32 $0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Float32( + llvm.inline_asm( + T.f32(), + [f16_bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) def _build_mul_cvt(out_fmt: str, relu: bool = False): """Build a fused `x2 * f32x2 → fp8x2` PTX wrapper. 
@@ -333,21 +416,27 @@ def _build_mul_cvt(out_fmt: str, relu: bool = False): "mov.b64 vp0, {v1, v2};\n\t" "mul.f32x2 vp1, vp0, $2;\n\t" "mov.b64 {v2, v1}, vp1;\n\t" - f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 $0, v1, v2;\n\t" + f"cvt.rn.satfinite{'.relu' if relu else ''}.{out_op}.f32 $0, v1, v2;\n\t" "}" ) @dsl_user_op def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32: - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_2x.ir_value(loc=loc, ip=ip), - scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=h,r,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_2x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=h,r,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return fn def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): @@ -386,12 +475,17 @@ def mul_cvt_to_fp4x4(in_4x: Int64, scale_2x: Int64, *, loc=None, ip=None) -> Int "mov.b32 $0, {f0, f1, f0, f1};\n\t" "}" ) - return Int32(llvm.inline_asm( - T.i32(), - [in_4x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=r,l,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32( + llvm.inline_asm( + T.i32(), + [in_4x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=r,l,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) return SimpleNamespace( abs_max_x2=abs_max_x2, @@ -447,8 +541,8 @@ def _act_gelu(x: Float32) -> Float32: rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation kernels without `--use_fast_math` by default, so its `tanhf` is the IEEE-precise 
expansion.""" - A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal - B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation + A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal + B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) @@ -466,40 +560,51 @@ def _act_silu(x: Float32) -> Float32: @dsl_user_op -def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: +def cvt_f32x2_to_fp8e4m3x2( + val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None +) -> Int32: """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. """ - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32(mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @dsl_user_op -def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: +def cvt_f32x2_to_fp8e5m2x2( + val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None +) -> Int32: """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" - 
result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32(mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @dsl_user_op -def mul_cvt_f32x4_to_fp4x4(in01: Int64, in23: Int64, scale_2x: Int64, - *, loc=None, ip=None) -> Int32: +def mul_cvt_f32x4_to_fp4x4( + in01: Int64, in23: Int64, scale_2x: Int64, *, loc=None, ip=None +) -> Int32: """f32x4 sibling of `kit.mul_cvt_to_fp4x4` — for the NVFP4 colwise path where elements live on a strided column and we've already widened to f32 for the amax reduction. 
`in01` = pack(f32_0, f32_1), `in23` similarly.""" @@ -521,14 +626,21 @@ def mul_cvt_f32x4_to_fp4x4(in01: Int64, in23: Int64, scale_2x: Int64, "mov.b32 $0, {f0, f1, f0, f1};\n\t" "}" ) - return Int32(llvm.inline_asm( - T.i32(), - [in01.ir_value(loc=loc, ip=ip), - in23.ir_value(loc=loc, ip=ip), - scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=r,l,l,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32( + llvm.inline_asm( + T.i32(), + [ + in01.ir_value(loc=loc, ip=ip), + in23.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip), + ], + asm, + "=r,l,l,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) def _cvt_f32_to_fp8(fp8_dtype: str): @@ -548,19 +660,21 @@ def _cvt_f32x2_to_fp8x2(fp8_dtype: str): return cvt_f32x2_to_fp8e5m2x2 return cvt_f32x2_to_fp8e4m3x2 + @cute.jit def quantize_rowwise_mxfp8( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) - mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, - tile_row_start, # Int32 — global row index of this stage's row 0 - # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores - # for irregular shapes. - tile_col_start, # Int32 — global col index of this CTA's col 0 - # (= bidx * TILE_X). Same purpose. - M, N, # Int32 — full tensor extents; OOB threads skip their - # scale store. + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). Same purpose. + M, + N, # Int32 — full tensor extents; OOB threads skip their + # scale store. 
ACTIVATION, DTYPE, FP8_DTYPE, @@ -583,20 +697,20 @@ def quantize_rowwise_mxfp8( # print(f"sX_tile: {sX_tile}") # print(f"sO_row_tile: {sO_row_tile}") # print(f"mS_row_stage: {mS_row_stage}") - + tiler, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), - val_layout=cute.make_layout((1, SCALE_DIM), stride=(0, 1)) + val_layout=cute.make_layout((1, SCALE_DIM), stride=(0, 1)), ) # print(f"tv_layout: {tv_layout}") # print(f"tiler: {tiler}") - + sX_tv = cute.composition(sX_tile, tv_layout) sO_tv = cute.composition(sO_row_tile, tv_layout) # I/O Elements that belong to this thread - sX_thread = sX_tv[tidx, None] # shape (32,) bf16 - sO_thread = sO_tv[tidx, None] # shape (32,) uint8 + sX_thread = sX_tv[tidx, None] # shape (32,) bf16 + sO_thread = sO_tv[tidx, None] # shape (32,) uint8 # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C1%29-%281%2C+32%29%3A%280%2C1%29 # print(f"sX_thread: {sX_thread}") @@ -606,13 +720,13 @@ def quantize_rowwise_mxfp8( # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. 
sO_thread_u32 = cute.make_tensor( sO_thread_u32_ptr, - cute.make_layout((SCALE_DIM // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + cute.make_layout((SCALE_DIM // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements ) # print(f"sO_thread_u32: {sO_thread_u32}") FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") # For this fast paht we can read in pack of 2 instead of reading individual f16 / bf16 element - _row_fast = (_is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU)) + _row_fast = _is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) if cutlass.const_expr(_row_fast): # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack @@ -620,14 +734,16 @@ def quantize_rowwise_mxfp8( kit = _packed16_kit(DTYPE) sX_thread_rw_i32 = cute.make_tensor( cute.recast_ptr(sX_thread.iterator, dtype=Int32), - cute.make_layout((1, SCALE_DIM // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + cute.make_layout((1, SCALE_DIM // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements ) # print(f"sX_thread_rw_i32: {sX_thread_rw_i32}") # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) - # In total we have 8 waves where each wave reads + # In total we have 8 waves where each wave reads in_r = [[None, None] for _ in range(WAVES)] - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group - offset = bank_group * 2 # Each bank group will read 2 i32 from their bank + bank_group = ( + tidx % THREADS_PER_WARP + ) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 2 # Each bank group will read 2 i32 from their bank for w in cutlass.range_constexpr(WAVES): idx = (w * 2 + offset) % (SCALE_DIM // 2) in_r[w][0] = sX_thread_rw_i32[0, idx] @@ -668,8 +784,10 @@ def quantize_rowwise_mxfp8( cute.make_layout((1, SCALE_DIM), stride=(0, 1)), ) in_r = [[None] * PACK_SIZE for _ in range(WAVES)] - bank_group = 
(tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group - offset = bank_group * 4 # Each bank group will read 4 f16 from their bank + bank_group = ( + tidx % THREADS_PER_WARP + ) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank if cutlass.const_expr(ACTIVATION is not None): op = _ACTIVATIONS[ACTIVATION] @@ -688,14 +806,16 @@ def quantize_rowwise_mxfp8( x = op(x) # If 16-bit input with activation, truncate to IType if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): - x = kit_act.truncate_f32(x) # TODO: Why not just qunatize from f32? + x = kit_act.truncate_f32(x) # TODO: Why not just qunatize from f32? in_r[w][e] = x if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically + amax_r = cute.arch.fmax( + amax_r, x + ) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically else: amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 # 2. E8M0 scale → gmem. mS_row's layout already encodes the swizzle # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. @@ -714,7 +834,7 @@ def quantize_rowwise_mxfp8( mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) # 3. scale + packed fp8 cast → smem as one u32 per wave. 
- inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale + inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale # Fetch the conversion function based on the FP8 format cvt_f32x2 = _cvt_f32x2_to_fp8x2(FP8_DTYPE) if cutlass.const_expr(_row_fast): @@ -724,8 +844,10 @@ def quantize_rowwise_mxfp8( # the per-wave mul_cvt consumes this directly. scale_2x = pack_f32x2(inv_scale_r, inv_scale_r) - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group - offset = bank_group * 4 # Each bank group will write 4 fp8 to + bank_group = ( + tidx % THREADS_PER_WARP + ) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will write 4 fp8 to for w in cutlass.range_constexpr(WAVES): idx = (w * 4 + offset) % SCALE_DIM idx = idx // 4 @@ -749,18 +871,20 @@ def quantize_rowwise_mxfp8( return amax_r + @cute.jit def quantize_colwise_mxfp8( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) - mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, - tile_row_start, # Int32 — global row index of this stage's row 0 - # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores - # for irregular shapes. - tile_col_start, # Int32 — global col index of this CTA's col 0 - # (= bidx * TILE_X). - M, N, # Int32 — full tensor extents. + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). + M, + N, # Int32 — full tensor extents. 
ACTIVATION, DTYPE, FP8_DTYPE, @@ -777,7 +901,7 @@ def quantize_colwise_mxfp8( tiler, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), - val_layout=cute.make_layout((SCALE_DIM, 1), stride=(1, 1)) + val_layout=cute.make_layout((SCALE_DIM, 1), stride=(1, 1)), ) # print(f"tv_layout: {tv_layout}") @@ -893,17 +1017,18 @@ def quantize_colwise_mxfp8( @cute.jit def quantize_rowwise_nvfp4( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_row_tile, # (TILE_Y, TILE_X // 2) uint8 smem view (rowwise FP4 output) - mS_row_stage, # (TILE_Y, TILE_X // SCALE_DIM) uint8 — one E4M3 byte per (row, scale-block) - S_enc, # Float32 — precomputed global encode scale (uniform across threads) - tile_row_start, # Int32 — global row index of this stage's row 0 - tile_col_start, # Int32 — global col index of this CTA's col 0 - M, N, # Int32 — full tensor extents; OOB threads skip scale store + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X // 2) uint8 smem view (rowwise FP4 output) + mS_row_stage, # (TILE_Y, TILE_X // SCALE_DIM) uint8 — one E4M3 byte per (row, scale-block) + S_enc, # Float32 — precomputed global encode scale (uniform across threads) + tile_row_start, # Int32 — global row index of this stage's row 0 + tile_col_start, # Int32 — global col index of this CTA's col 0 + M, + N, # Int32 — full tensor extents; OOB threads skip scale store DTYPE, TILE_Y, TILE_X, - SCALE_DIM, # = SCALE_DIM_NVFP4 (16); explicit for symmetry with MXFP8 fn + SCALE_DIM, # = SCALE_DIM_NVFP4 (16); explicit for symmetry with MXFP8 fn ): """Rowwise NVFP4 pass — reuses the MXFP8 rowwise 64-thread layout. 
@@ -917,7 +1042,7 @@ def quantize_rowwise_nvfp4( """ tidx, _, _ = cute.arch.thread_idx() - SEG = TILE_X // 2 # 32 — elements per thread (half a row) + SEG = TILE_X // 2 # 32 — elements per thread (half a row) BLOCKS_PER_SEG = SEG // SCALE_DIM # 2 — NVFP4 scale-blocks in that segment # Same TV layout as MXFP8 rowwise (thr (TILE_Y, 2), val (1, SEG)); the @@ -933,8 +1058,8 @@ def quantize_rowwise_nvfp4( sX_tv = cute.composition(sX_tile, tv_layout_in) sO_tv = cute.composition(sO_row_tile, tv_layout_out) - sX_thread = sX_tv[tidx, None] # (SEG,) bf16/fp16 - sO_thread = sO_tv[tidx, None] # (SEG // 2,) uint8 + sX_thread = sX_tv[tidx, None] # (SEG,) bf16/fp16 + sO_thread = sO_tv[tidx, None] # (SEG // 2,) uint8 row = tidx // 2 seg = tidx % 2 @@ -958,7 +1083,7 @@ def quantize_rowwise_nvfp4( # block_scale_inverse → SCALE_DIM/4 fp4x4 stores. Identical math to the # colwise path, just along X with the packed-x2 amax fast read. for blk in cutlass.range_constexpr(BLOCKS_PER_SEG): - i32_base = blk * (SCALE_DIM // 2) # 8 Int32 per block + i32_base = blk * (SCALE_DIM // 2) # 8 Int32 per block in_r = [None] * (SCALE_DIM // 2) for w in cutlass.range_constexpr(SCALE_DIM // 2): in_r[w] = sX_thread_i32[i32_base + w] @@ -987,7 +1112,8 @@ def quantize_rowwise_nvfp4( # 3. block_scale_inverse = min(1 / (f32(S_dec_b_fp8) * S_dec), FLT_MAX). S_dec_b_f32 = cvt_fp8e4m3_to_f32(S_dec_b_byte) block_scale_inverse = cute.arch.fmin( - Float32(1.0) / (S_dec_b_f32 * S_dec), Float32(FP32_MAX)) + Float32(1.0) / (S_dec_b_f32 * S_dec), Float32(FP32_MAX) + ) scale_2x = pack_f32x2(block_scale_inverse, block_scale_inverse) # 4. Cast SCALE_DIM elements → SCALE_DIM/4 × fp4x4 (2 bytes each). @@ -1004,23 +1130,24 @@ def quantize_rowwise_nvfp4( @cute.jit def quantize_colwise_nvfp4( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_col_tile, # (TILE_X, TILE_Y // 2) uint8 smem view — TRANSPOSED FP4 output - # (row `tidx` = column `tidx` of input; bytes pack vertically - # adjacent input rows). 
Matches TE's NVFP4 columnwise data - # storage shape `(N, M // 2)`, so the caller's TMA S2G goes - # straight to the right gmem layout with no extra transpose. - mS_col_stage, # (TILE_X, TILE_Y // SCALE_DIM) uint8 — one E4M3 byte per - # (col, scale-block-y). Also transposed to match TE's NVFP4 - # columnwise scale shape `(N, M // 16)`. - S_enc, # Float32 — precomputed global encode scale - tile_row_start, # Int32 — global row index of this stage's row 0 - tile_col_start, # Int32 — global col index of this CTA's col 0 - M, N, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_X, TILE_Y // 2) uint8 smem view — TRANSPOSED FP4 output + # (row `tidx` = column `tidx` of input; bytes pack vertically + # adjacent input rows). Matches TE's NVFP4 columnwise data + # storage shape `(N, M // 2)`, so the caller's TMA S2G goes + # straight to the right gmem layout with no extra transpose. + mS_col_stage, # (TILE_X, TILE_Y // SCALE_DIM) uint8 — one E4M3 byte per + # (col, scale-block-y). Also transposed to match TE's NVFP4 + # columnwise scale shape `(N, M // 16)`. + S_enc, # Float32 — precomputed global encode scale + tile_row_start, # Int32 — global row index of this stage's row 0 + tile_col_start, # Int32 — global col index of this CTA's col 0 + M, + N, DTYPE, TILE_X, TILE_Y, - SCALE_DIM, # = SCALE_DIM_NVFP4 (16) + SCALE_DIM, # = SCALE_DIM_NVFP4 (16) ): tidx, _, _ = cute.arch.thread_idx() BLOCKS_PER_COL = TILE_Y // SCALE_DIM # e.g. 
2 for MXFP8's TILE_Y=32 reused @@ -1033,7 +1160,7 @@ def quantize_colwise_nvfp4( val_layout=cute.make_layout((TILE_Y, 1), stride=(1, 1)), ) sX_tv = cute.composition(sX_tile, tv_layout_in) - sX_thread = sX_tv[tidx, None] # (TILE_Y,) bf16/fp16, column tidx + sX_thread = sX_tv[tidx, None] # (TILE_Y,) bf16/fp16, column tidx # The per-stage output tile arrives as a rank-1 nested mode (like the MXFP8 # smem tiles); rebuild a flat rank-2 (TILE_X, TILE_Y // 2) view over the same @@ -1093,7 +1220,8 @@ def quantize_colwise_nvfp4( # 3. block_scale_inverse (same min/FLT_MAX clamp as rowwise) S_dec_b_f32 = cvt_fp8e4m3_to_f32(S_dec_b_byte) block_scale_inverse = cute.arch.fmin( - Float32(1.0) / (S_dec_b_f32 * S_dec), Float32(FP32_MAX)) + Float32(1.0) / (S_dec_b_f32 * S_dec), Float32(FP32_MAX) + ) scale_2x = pack_f32x2(block_scale_inverse, block_scale_inverse) # 4. Cast SCALE_DIM elements → SCALE_DIM/4 × fp4x4. Output is @@ -1103,10 +1231,8 @@ def quantize_colwise_nvfp4( # Recast row `tidx` to an Int16 view of length TILE_Y // 4 per # full thread (or SCALE_DIM // 4 per scale-block iter). for w in cutlass.range_constexpr(SCALE_DIM // 4): - in01 = pack_f32x2(sX_thread_f32[base + 4 * w], - sX_thread_f32[base + 4 * w + 1]) - in23 = pack_f32x2(sX_thread_f32[base + 4 * w + 2], - sX_thread_f32[base + 4 * w + 3]) + in01 = pack_f32x2(sX_thread_f32[base + 4 * w], sX_thread_f32[base + 4 * w + 1]) + in23 = pack_f32x2(sX_thread_f32[base + 4 * w + 2], sX_thread_f32[base + 4 * w + 3]) quad = mul_cvt_f32x4_to_fp4x4(in01, in23, scale_2x) # quad low 16 bits: byte0 = (fp4(elt 4w+1) << 4) | fp4(elt 4w+0), # byte1 = (fp4(elt 4w+3) << 4) | fp4(elt 4w+2). 
Pair (4w, 4w+1) → diff --git a/transformer_engine/common/cutedsl/cutedsl_utils.py b/transformer_engine/common/cutedsl/cutedsl_utils.py index cfddc014e9..d7a9200e45 100644 --- a/transformer_engine/common/cutedsl/cutedsl_utils.py +++ b/transformer_engine/common/cutedsl/cutedsl_utils.py @@ -23,12 +23,14 @@ "none": None, } + def torch_to_cutlass_dtype(torch_dtype): if torch_dtype not in _torch_to_cutlass_dtype: raise ValueError(f"Unsupported torch dtype: {torch_dtype}") return _torch_to_cutlass_dtype[torch_dtype] + def str_to_te_dtype(str_dtype): if str_dtype not in _str_to_te_dtype: raise ValueError(f"Unsupported string dtype: {str_dtype}") - return _str_to_te_dtype[str_dtype] \ No newline at end of file + return _str_to_te_dtype[str_dtype] diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e9db6dfe4b..45c928f397 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -408,33 +408,35 @@ class NVFP4Quantizer : public Quantizer { }; class FlexQuantizer : public Quantizer { - public: - explicit FlexQuantizer(const py::handle& quantizer); + public: + explicit FlexQuantizer(const py::handle& quantizer); - NVTEScalingMode get_scaling_mode() const override { return NVTE_FLEX_1D_SCALING; } + NVTEScalingMode get_scaling_mode() const override { return NVTE_FLEX_1D_SCALING; } - void set_quantization_params(TensorWrapper* tensor) const override; + void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional device = std::nullopt, bool pin_memory = false) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t 
logical_first_dim, - size_t logical_last_dim) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; - std::pair convert_and_update_tensor(py::object tensor) const override; + std::pair convert_and_update_tensor(py::object tensor) const override; - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; - /*! @brief Reconstruct a high-precision tensor by dispatching this + /*! @brief Reconstruct a high-precision tensor by dispatching this * quantizer's registered tvm-ffi dequantize_func. */ - void dequantize(const TensorWrapper& input, TensorWrapper& out); + void dequantize(const TensorWrapper& input, TensorWrapper& out); - std::vector get_scale_shape(const std::vector& shape, DType dtype, bool columnwise) const; - std::vector get_scale_shape(size_t flat_first_dim, size_t flat_last_dim, DType dtype, bool columnwise) const; + std::vector get_scale_shape(const std::vector& shape, DType dtype, + bool columnwise) const; + std::vector get_scale_shape(size_t flat_first_dim, size_t flat_last_dim, DType dtype, + bool columnwise) const; private: // If nullopt, then skip quantizing that direction diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8902d0055b..7b12507861 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -107,8 +107,8 @@ void init_nvfp4_extensions() { void init_flex_extensions() { auto flex_module = py::module_::import("transformer_engine.pytorch.tensor.flex_tensor"); - FlexQuantizerClass = reinterpret_cast( - 
PyObject_GetAttrString(flex_module.ptr(), "FlexQuantizer")); + FlexQuantizerClass = + reinterpret_cast(PyObject_GetAttrString(flex_module.ptr(), "FlexQuantizer")); FlexTensorPythonClass = reinterpret_cast(PyObject_GetAttrString(flex_module.ptr(), "FlexTensor")); auto flex_base_module = diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 24ecb3504f..1f3f81c5bb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2648,11 +2648,13 @@ FlexQuantizer::FlexQuantizer(const py::handle& quantizer) : Quantizer(quantizer) this->dtype_column = dtype_col; this->columnwise_usage = true; } else { - NVTE_ERROR("Column-wise quantization for FlexQuantizer currently does not support this dtype"); + NVTE_ERROR( + "Column-wise quantization for FlexQuantizer currently does not support this dtype"); } } - NVTE_CHECK(rowwise_usage || columnwise_usage, "FlexQuantizer should have at least one direction quantized."); + NVTE_CHECK(rowwise_usage || columnwise_usage, + "FlexQuantizer should have at least one direction quantized."); // Name of the tvm-ffi global function (a CuTeDSL kernel the Python side // already compiled + registered via tvm_ffi.register_global_func). quantize() @@ -2678,12 +2680,14 @@ std::pair FlexQuantizer::create_tensor( // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); - const auto rowwise_scale_inv_shape = this->dtype_row - ? std::optional(get_scale_shape(flat_first_dim, flat_last_dim, *this->dtype_row, false)) - : std::nullopt; - const auto columnwise_scale_inv_shape = this->dtype_column - ? std::optional(get_scale_shape(flat_first_dim, flat_last_dim, *this->dtype_column, true)) - : std::nullopt; + const auto rowwise_scale_inv_shape = + this->dtype_row + ? 
std::optional(get_scale_shape(flat_first_dim, flat_last_dim, *this->dtype_row, false)) + : std::nullopt; + const auto columnwise_scale_inv_shape = + this->dtype_column + ? std::optional(get_scale_shape(flat_first_dim, flat_last_dim, *this->dtype_column, true)) + : std::nullopt; // Allocate tensors for quantized data and scaling factors at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise_tensor; @@ -2697,12 +2701,12 @@ std::pair FlexQuantizer::create_tensor( if (this->dtype_row) { if (is_fp8_dtype(*this->dtype_row)) { const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape->begin(), - rowwise_scale_inv_shape->end()); + rowwise_scale_inv_shape->end()); rowwise_data_tensor = at::empty(shape_int64, bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); } else if (is_fp4_dtype(*this->dtype_row)) { const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape->begin(), - rowwise_scale_inv_shape->end()); + rowwise_scale_inv_shape->end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); amax_rowwise_tensor = at::empty({1}, bit32_tensor_opts); @@ -2712,12 +2716,12 @@ std::pair FlexQuantizer::create_tensor( if (this->dtype_column) { if (is_fp8_dtype(*this->dtype_column)) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape->begin(), - columnwise_scale_inv_shape->end()); + columnwise_scale_inv_shape->end()); columnwise_data_tensor = at::empty(shape_int64, bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); } else if (is_fp4_dtype(*this->dtype_column)) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape->begin(), - columnwise_scale_inv_shape->end()); + columnwise_scale_inv_shape->end()); columnwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); columnwise_scale_inv_tensor = 
at::empty(scale_inv_shape_int64, bit8_tensor_opts); amax_columnwise_tensor = at::empty({1}, bit32_tensor_opts); @@ -2781,8 +2785,8 @@ std::pair FlexQuantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); py::tuple args(0); - PyObject* result = PyObject_Call(reinterpret_cast(FlexTensorPythonClass), - args.ptr(), kwargs.ptr()); + PyObject* result = + PyObject_Call(reinterpret_cast(FlexTensorPythonClass), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); } @@ -2796,16 +2800,20 @@ std::pair FlexQuantizer::create_tensor( if (rowwise_usage && this->dtype_row) { out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), *this->dtype_row, shape); if (is_fp8_dtype(*this->dtype_row)) { - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, *rowwise_scale_inv_shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + *rowwise_scale_inv_shape); } else if (is_fp4_dtype(*this->dtype_row)) { - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, *rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise_tensor.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise_tensor)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + *rowwise_scale_inv_shape); + out_cpp.set_amax(amax_rowwise_tensor.data_ptr(), DType::kFloat32, + getTensorShape(amax_rowwise_tensor)); } } if (columnwise_usage && this->dtype_column) { if (is_fp8_dtype(*this->dtype_column)) { out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), *this->dtype_column, shape); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, *columnwise_scale_inv_shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + *columnwise_scale_inv_shape); } else if (is_fp4_dtype(*this->dtype_column)) { // 
Follow the pattern of NVFP4's columnwise data layout std::vector shape_2d = {flat_first_dim, flat_last_dim}; @@ -2813,7 +2821,7 @@ std::pair FlexQuantizer::create_tensor( out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, - *columnwise_scale_inv_shape); + *columnwise_scale_inv_shape); out_cpp.set_columnwise_amax(amax_columnwise_tensor.data_ptr(), DType::kFloat32, std::vector{1}); } @@ -2825,9 +2833,9 @@ std::pair FlexQuantizer::create_tensor( } std::pair FlexQuantizer::create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const { + size_t num_tensors, const std::vector& logical_shape, DType dtype, py::object quantizer, + const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const { // TODO: fix this NVTE_ERROR("Not implemented yet"); } @@ -2863,8 +2871,9 @@ std::pair FlexQuantizer::convert_and_update_tensor( if (this->dtype_column && is_fp4_dtype(*this->dtype_column)) { // If both rowwise and columnwise directions are NVFP4 quantized, check if they match auto col_shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); - NVTE_CHECK(get_2d_dims(shape) == get_2d_dims(col_shape), "NVFP4 row-wise data (shape=", shape, - ") and column-wise data (shape=", col_shape, ") do not match"); + NVTE_CHECK(get_2d_dims(shape) == get_2d_dims(col_shape), + "NVFP4 row-wise data (shape=", shape, + ") and column-wise data (shape=", col_shape, ") do not match"); } } else if (is_fp8_dtype(*this->dtype_row)) { shape = getTensorShape(*rowwise_data); @@ -2900,12 +2909,12 @@ std::pair FlexQuantizer::convert_and_update_tensor( if (is_fp8_dtype(*this->dtype_row)) { const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_row, false); const 
std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + scale_inv_shape.end()); rowwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts); } else if (is_fp4_dtype(*this->dtype_row)) { const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_row, false); const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + scale_inv_shape.end()); rowwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts); } else { NVTE_ERROR("Unsupported dtype for row-wise quantization in FlexQuantizer: ", @@ -2942,12 +2951,12 @@ std::pair FlexQuantizer::convert_and_update_tensor( // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero std::vector shape_int64_2d = {static_cast(flat_first_dim), - static_cast(flat_last_dim)}; + static_cast(flat_last_dim)}; const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), uint8_opts); } else { NVTE_ERROR("Unsupported dtype for column-wise quantization in FlexQuantizer: ", - static_cast(*this->dtype_column)); + static_cast(*this->dtype_column)); } tensor.attr("_columnwise_data") = *columnwise_data; } @@ -2955,16 +2964,16 @@ std::pair FlexQuantizer::convert_and_update_tensor( if (is_fp8_dtype(*this->dtype_column)) { const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_column, true); const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + scale_inv_shape.end()); columnwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts); } else if (is_fp4_dtype(*this->dtype_column)) { const auto scale_inv_shape = get_scale_shape(shape, *this->dtype_column, true); const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + scale_inv_shape.end()); columnwise_scale_inv = at::empty(scale_inv_shape_int64, uint8_opts); } else { 
NVTE_ERROR("Unsupported dtype for column-wise quantization in FlexQuantizer: ", - static_cast(*this->dtype_column)); + static_cast(*this->dtype_column)); } tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; } @@ -2973,7 +2982,7 @@ std::pair FlexQuantizer::convert_and_update_tensor( amax_columnwise = at::empty({1}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } - } else { // columnwise_usage == false + } else { // columnwise_usage == false if (columnwise_data) { columnwise_data.reset(); tensor.attr("_columnwise_data") = py::none(); @@ -3029,7 +3038,7 @@ std::pair FlexQuantizer::convert_and_update_tensor( } void FlexQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { + const std::optional& noop_flag) { if (input.numel() == 0) { return; } @@ -3057,13 +3066,13 @@ void FlexQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, // it actually needs. constexpr int kNumArgs = 7; const std::array nvte_args = { - input.get_rowwise_data(), // mX - out.get_rowwise_data(), // mO_row - out.get_rowwise_scale_inv(), // mS_row - out.get_amax(), // mA_row (primary / row-wise amax) - out.get_columnwise_data(), // mO_col - out.get_columnwise_scale_inv(), // mS_col - out.get_columnwise_amax(), // mA_col + input.get_rowwise_data(), // mX + out.get_rowwise_data(), // mO_row + out.get_rowwise_scale_inv(), // mS_row + out.get_amax(), // mA_row (primary / row-wise amax) + out.get_columnwise_data(), // mO_col + out.get_columnwise_scale_inv(), // mS_col + out.get_columnwise_amax(), // mA_col }; // Named locals: each DLTensorWrapper owns the synthesized DLTensor shape/ @@ -3093,11 +3102,11 @@ void FlexQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, // quantize_mxfp8.cuh's per-call cudaMemsetAsync). Scales are 1-byte // (E8M0/E4M3), so the element count is the byte count. 
if (this->optimize_for_gemm) { - const auto &scale_row = dltensors[2]; // mS_row + const auto& scale_row = dltensors[2]; // mS_row if (scale_row) { NVTE_CHECK_CUDA(cudaMemsetAsync(scale_row->data, 0, scale_row->numel(), stream)); } - const auto &scale_col = dltensors[5]; // mS_col + const auto& scale_col = dltensors[5]; // mS_col if (scale_col) { NVTE_CHECK_CUDA(cudaMemsetAsync(scale_col->data, 0, scale_col->numel(), stream)); } @@ -3112,16 +3121,16 @@ void FlexQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, // when absent), packed by tvm-ffi's operator() and kept alive through the // whole call expression. The wrappers back the views and outlive the call. // (mX, mO_row, mS_row, mA_row, mO_col, mS_col, mA_col, rng_state, stream) - call_tvm_ffi(this->quantize_func, - to_ffi_arg(dltensors[0]), // input tensor - to_ffi_arg(dltensors[1]), // rowwise quantized data - to_ffi_arg(dltensors[2]), // rowwise scale_inv - to_ffi_arg(dltensors[3]), // rowwise amax - to_ffi_arg(dltensors[4]), // columnwise quantized data - to_ffi_arg(dltensors[5]), // columnwise scale_inv - to_ffi_arg(dltensors[6]), // columnwise amax - to_ffi_arg(rng_w), // rng_state (used for stochastic rounding if you need it) - stream_handle // CUDA stream handle as int64 + call_tvm_ffi(this->quantize_func, + to_ffi_arg(dltensors[0]), // input tensor + to_ffi_arg(dltensors[1]), // rowwise quantized data + to_ffi_arg(dltensors[2]), // rowwise scale_inv + to_ffi_arg(dltensors[3]), // rowwise amax + to_ffi_arg(dltensors[4]), // columnwise quantized data + to_ffi_arg(dltensors[5]), // columnwise scale_inv + to_ffi_arg(dltensors[6]), // columnwise amax + to_ffi_arg(rng_w), // rng_state (used for stochastic rounding if you need it) + stream_handle // CUDA stream handle as int64 ); } @@ -3177,13 +3186,13 @@ void FlexQuantizer::dequantize(const TensorWrapper& input, TensorWrapper& out) { } std::vector FlexQuantizer::get_scale_shape(const std::vector& shape, DType dtype, - bool columnwise) const 
{ + bool columnwise) const { const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); return get_scale_shape(flat_first_dim, flat_last_dim, dtype, columnwise); } std::vector FlexQuantizer::get_scale_shape(size_t flat_first_dim, size_t flat_last_dim, - DType dtype, bool columnwise) const { + DType dtype, bool columnwise) const { // Each direction uses its own format's scale-factor shape, mirroring the // per-format MXFP8Quantizer / NVFP4Quantizer get_scale_shape implementations. if (is_fp8_dtype(dtype)) { diff --git a/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h b/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h index 9da1246210..06db1da25d 100644 --- a/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h +++ b/transformer_engine/pytorch/csrc/tvm_ffi_bridge.h @@ -7,13 +7,6 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_TVM_FFI_BRIDGE_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_TVM_FFI_BRIDGE_H_ -#include -#include -#include -#include -#include -#include - #include #include #include @@ -24,6 +17,13 @@ #include #include +#include +#include +#include +#include +#include +#include + #include "transformer_engine/transformer_engine.h" #include "util/logging.h" @@ -39,17 +39,27 @@ namespace py = pybind11; // support) requires this manual mapping. 
inline DLDataType convert_to_dltype(NVTEDType type) { switch (type) { - case kNVTEFloat32: return DLDataType{kDLFloat, 32, 1}; - case kNVTEFloat16: return DLDataType{kDLFloat, 16, 1}; - case kNVTEBFloat16: return DLDataType{kDLBfloat, 16, 1}; - case kNVTEByte: return DLDataType{kDLUInt, 8, 1}; - case kNVTEInt32: return DLDataType{kDLInt, 32, 1}; - case kNVTEInt64: return DLDataType{kDLInt, 64, 1}; + case kNVTEFloat32: + return DLDataType{kDLFloat, 32, 1}; + case kNVTEFloat16: + return DLDataType{kDLFloat, 16, 1}; + case kNVTEBFloat16: + return DLDataType{kDLBfloat, 16, 1}; + case kNVTEByte: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEInt32: + return DLDataType{kDLInt, 32, 1}; + case kNVTEInt64: + return DLDataType{kDLInt, 64, 1}; // FP8 / E8M0 → raw 1-byte uint; the kernel interprets the bits. - case kNVTEFloat8E4M3: return DLDataType{kDLUInt, 8, 1}; - case kNVTEFloat8E5M2: return DLDataType{kDLUInt, 8, 1}; - case kNVTEFloat8E8M0: return DLDataType{kDLUInt, 8, 1}; - default: NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); + case kNVTEFloat8E4M3: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E5M2: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E8M0: + return DLDataType{kDLUInt, 8, 1}; + default: + NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); } } @@ -80,21 +90,21 @@ class DLTensorWrapper : public DLTensor { // contiguous strides (TE tensors are always contiguous). 
DLTensorWrapper(const NVTEBasicTensor &tensor, int32_t device_index) { const int n = static_cast(tensor.shape.ndim); - shape_buf_ = std::make_unique(n); + shape_buf_ = std::make_unique(n); strides_buf_ = std::make_unique(n); int64_t stride = 1; for (int i = n - 1; i >= 0; --i) { - shape_buf_[i] = static_cast(tensor.shape.data[i]); + shape_buf_[i] = static_cast(tensor.shape.data[i]); strides_buf_[i] = stride; stride *= shape_buf_[i]; } - this->numel_ = stride; // product of all dims - this->data = tensor.data_ptr; - this->device = DLDevice{kDLCUDA, device_index}; - this->ndim = n; - this->dtype = convert_to_dltype(tensor.dtype); - this->shape = shape_buf_.get(); - this->strides = strides_buf_.get(); + this->numel_ = stride; // product of all dims + this->data = tensor.data_ptr; + this->device = DLDevice{kDLCUDA, device_index}; + this->ndim = n; + this->dtype = convert_to_dltype(tensor.dtype); + this->shape = shape_buf_.get(); + this->strides = strides_buf_.get(); this->byte_offset = 0; } @@ -158,8 +168,7 @@ inline tvm::ffi::Optional to_ffi_arg( template inline tvm::ffi::Any call_tvm_ffi(const std::string &fn_name, Args &&...args) { std::optional fn = tvm::ffi::Function::GetGlobal(fn_name); - NVTE_CHECK(fn.has_value(), - "No tvm-ffi kernel registered under '", fn_name, + NVTE_CHECK(fn.has_value(), "No tvm-ffi kernel registered under '", fn_name, "'. 
This name is the quantizer's cache key, which encodes the " "kernel's compile-time (constexpr) signature; a miss means the " "registered kernel's constexpr guarantee does not match what is " diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3758443d53..4c4dcff218 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -202,7 +202,8 @@ TensorWrapper NVTETensorFromFlexTensor(py::handle tensor, Quantizer *quantizer) const DType dtype_column = tensor.attr("_dtype_column").cast(); const auto &columnwise_data = tensor.attr("_columnwise_data").cast(); const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); - ret.set_columnwise_data(columnwise_data.data_ptr(), dtype_column, getTensorShape(columnwise_data)); + ret.set_columnwise_data(columnwise_data.data_ptr(), dtype_column, + getTensorShape(columnwise_data)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, getTensorShape(scale_inv)); if (is_fp4_dtype(dtype_column)) { const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); diff --git a/transformer_engine/pytorch/tensor/flex_tensor.py b/transformer_engine/pytorch/tensor/flex_tensor.py index 0640eaf0fe..bea0fbe5bb 100644 --- a/transformer_engine/pytorch/tensor/flex_tensor.py +++ b/transformer_engine/pytorch/tensor/flex_tensor.py @@ -24,6 +24,7 @@ aten = torch.ops.aten + class FlexQuantizer(Quantizer): """Builder class for Flex tensors that are quantized in potentially both directions @@ -184,11 +185,7 @@ def copy(self) -> FlexQuantizer: return quantizer def update_quantized( - self, - src: torch.Tensor, - dst: QuantizedTensor, - *, - noop_flag = None + self, src: torch.Tensor, dst: QuantizedTensor, *, noop_flag=None ) -> QuantizedTensor: assert isinstance(dst, FlexTensor), f"Cannot store quantized MXFP8 in {type(dst)} type." 
@@ -225,10 +222,7 @@ def calibrate(self, tensor: torch.Tensor) -> None: pass # Calibration is no-op since this supports only blockwise quantization, which doesn't require calibration. def get_scale_shape( - self, - shape: Iterable[int], - columnwise: bool, - dtype: TE_DType + self, shape: Iterable[int], columnwise: bool, dtype: TE_DType ) -> Tuple[int, int]: """Calculate the shape of the scaling tensor for Flex quantization. @@ -265,7 +259,9 @@ def get_scale_shape( elif FlexTensor.is_mxfp8_dtype(dtype): if columnwise: return ( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4 + ), round_up_to_nearest_multiple(shape[-1], 128), ) return ( @@ -315,6 +311,7 @@ def _get_compatible_recipe(self) -> Union[Recipe, None]: # TODO: really? return None # Flex quantizer does not have a specific compatible recipe since it's orthogonal to the choice of recipe. + class FlexTensor(FlexTensorStorage, QuantizedTensor): """Tensor class for flex tensors with quantization in both directions. @@ -422,7 +419,7 @@ def clone(self) -> FlexTensor: { "rowwise_data": rowwise_data, "columnwise_data": columnwise_data, - } + }, ) def view(self, *shape: Tuple[int]) -> FlexTensor: @@ -456,9 +453,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # pylint: disable=missing-function-docstring # TODO: handle aten ops (view/copy_/split/...) per direction, # following MXFP8Tensor.__torch_dispatch__. - raise NotImplementedError( - f"FlexTensor.__torch_dispatch__ does not support {func} yet" - ) + raise NotImplementedError(f"FlexTensor.__torch_dispatch__ does not support {func} yet") # ------------------------------------------------------------------ # FSDP2 is not supported yet. 
Define the hooks so any accidental @@ -653,6 +648,7 @@ def _flex_rowwise_byte_shape(dtype: Optional[TE_DType], shape) -> tuple: return shape[:-1] + (shape[-1] // 2,) return tuple(shape) + def _flex_columnwise_byte_shape(dtype: Optional[TE_DType], shape) -> tuple: """Physical column-wise buffer shape for a logical ``shape``. diff --git a/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py index 52d65b81d1..bbe9c4b0f0 100644 --- a/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/flex_tensor_storage.py @@ -2,8 +2,7 @@ # # See LICENSE for license information. -"""TODO: write comments -""" +"""TODO: write comments""" from __future__ import annotations from typing import Optional, Dict, Any, Tuple, Union @@ -13,16 +12,17 @@ import torch import transformer_engine_torch as tex # pylint: disable=unused-import -from transformer_engine_torch import ( - DType as TE_DType -) +from transformer_engine_torch import DType as TE_DType from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype # pylint: disable=unused-import +from ...constants import ( + TE_DType as torch_to_transformer_engine_dtype, +) # pylint: disable=unused-import from ...utils import _empty_tensor, canonicalize_shape + class _FromFlexFunc(torch.autograd.Function): """Cast from MXFP8 to other dtype""" @@ -31,19 +31,18 @@ def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused tensor: FlexTensorStorage, dtype: torch.dtype, - quantizer: Quantizer + quantizer: Quantizer, ) -> torch.Tensor: # pylint: disable=missing-function-docstring if tensor._rowwise_data is not None and tensor._rowwise_data.numel() == 0: return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: return 
torch.empty(tensor.size(), dtype=dtype, device=tensor.device) - + if tensor._rowwise_data is not None or tensor._columnwise_data is not None: return tex.dequantize_with_quantizer(tensor, dtype, quantizer) raise ValueError("Cannot dequantize Flex tensor with no data") - @staticmethod def backward( _ctx: torch.autograd.function.FunctionCtx, # unused diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index baf90e7a4a..8613de7c27 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -971,6 +971,7 @@ def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torc "Valid types are: torch.Tensor, tuple, list, dict, int, float, bool, and None." ) + def canonicalize_shape(shape, cur_shape): if not isinstance(shape, Iterable): shape = [shape] @@ -983,4 +984,4 @@ def canonicalize_shape(shape, cur_shape): if d == -1: shape[i] = d_inferred break - return shape \ No newline at end of file + return shape