From 668a4b77d7be51aeb692f87a3d7816b475ac214c Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Wed, 15 Apr 2026 09:13:07 -0700 Subject: [PATCH 1/6] [Hadamard] Cache offsets/first_dims in smem via cp.async for graph-safe kernel Use LDGSTS (cp.async) in the DMA warp to load offsets and first_dims arrays from global memory into shared memory, replacing direct global reads in the epilogue/row-quant warps. Adds cpasync_barrier for the non-RHTColQuant path and smem_offsets/smem_first_dims fields to SharedStorage. Includes offset-caching unit tests and swizzle benchmark. Signed-off-by: Siddhartha Raman S --- .../pytorch/nvfp4/bench_graph_safe_swizzle.py | 26 ++ ...cast_col_hadamard_transform_cast_fusion.cu | 100 +++++- .../hadamard_transform/test_offset_caching.cu | 313 ++++++++++++++++++ 3 files changed, 431 insertions(+), 8 deletions(-) create mode 100644 tests/pytorch/nvfp4/bench_graph_safe_swizzle.py create mode 100644 transformer_engine/common/hadamard_transform/test_offset_caching.cu diff --git a/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py new file mode 100644 index 0000000000..29e30b7727 --- /dev/null +++ b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py @@ -0,0 +1,26 @@ +import torch +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + +M, N = 8192, 7168 # your actual shape +x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") +split_sections = torch.tensor([128] * (M // 128), dtype=torch.int64, device="cuda") + +for optimize_for_gemm in [False, True]: + q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + q.optimize_for_gemm = optimize_for_gemm + + # warmup + for _ in range(10): + tex.group_quantize(x, q, split_sections.shape[0], split_sections) + torch.cuda.synchronize() + + # time + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(100): + 
tex.group_quantize(x, q, split_sections.shape[0], split_sections) + end.record() + torch.cuda.synchronize() + print(f"optimize_for_gemm={optimize_for_gemm}: {start.elapsed_time(end) / 100 * 1000:.1f} μs") diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 0c3a5e9299..568c876cfa 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -162,6 +162,9 @@ struct SharedStorage { alignas(16) AccumulatorPipelineStorage accumulator; alignas(16) MainloopPipelineStorage mainloop; alignas(16) cute::uint64_t tma_barrier[1]; + // Used only when kEnableRHTColQuant=false: signals cp.async completion of offsets/first_dims + // to the row quant warp without going through tma_barrier[0] (which is never signaled in that path). 
+ alignas(16) cute::uint64_t cpasync_barrier[1]; alignas(16) SchedPipelineStorage sched; alignas(16) SchedThrottlePipelineStorage sched_throttle; alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_]; @@ -169,6 +172,8 @@ struct SharedStorage { alignas(16) float global_d_amax[kMaxTensorsPerKernel]; uint32_t atomic_tile_counter[SchedulerPipelineStageCount_]; uint32_t tmem_base_ptr; + alignas(8) int64_t smem_offsets[kMaxTensorsPerKernel + 1]; + alignas(8) int64_t smem_first_dims[kMaxTensorsPerKernel]; }; // Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support @@ -376,6 +381,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g EpilogueUnrollFactor, SchedulerPipelineStageCount>; SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + // offsets_smem and first_dims_smem point to the smem cache populated by the DMA warp via + // LDGSTS (cp.async). tma_barrier[0] is initialized with 2 arrivals: the DMA warp arrives + // after cp.async.wait_all (ensuring offsets/first_dims are in smem), and TMA B hardware + // arrives when the Hadamard matrix load completes. Epilogue warps wait on tma_barrier[0]. + const int64_t *const offsets_smem = shared_storage.smem_offsets; + const int64_t *const first_dims_smem = shared_storage.smem_first_dims; + // Compute the number of tiles in M and N after tiling and assign scheduler uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler)))); @@ -500,14 +512,66 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g cutlass::make_producer_start_state(); if (warp_idx == 2 && elect_one_sync()) { - cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + if constexpr (kEnableRHTColQuant) { + // Two arrivals required: (1) DMA warp after cp.async completes, (2) TMA B hardware completion. 
+ cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 2); + } else { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + cute::initialize_barrier(shared_storage.cpasync_barrier[0], /* num_threads */ 1); + } } __syncthreads(); + if (threadIdx.x == 0) printf("[DBG] block=%d after barrier init\n", blockIdx.x); // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer if (is_dma_warp) { // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). cutlass::arch::warpgroup_reg_dealloc<32>(); + if (elect_one_sync()) printf("[DBG] block=%d DMA warp: entered\n", blockIdx.x); + + // Use LDGSTS (cp.async) in the DMA warp to load offsets and first_dims into shared memory. + // For kEnableRHTColQuant=true: after cp.async completes, the DMA warp provides the 1st of + // 2 arrivals at tma_barrier[0] (with release semantics). TMA B hardware completion provides + // the 2nd arrival, firing the barrier. Epilogue warps wait on tma_barrier[0] before reading + // offsets_smem/first_dims_smem. + // For kEnableRHTColQuant=false: cpasync_barrier[0] is used instead. 
+ constexpr int kWarpSize = 32; + const int local_tidx = threadIdx.x % kWarpSize; + if (elect_one_sync()) printf("[DBG] block=%d DMA warp: before offsets loop, num_tensors=%zu local_tidx=%d\n", blockIdx.x, num_tensors, (int)local_tidx); + auto async_op = cute::SM80_CP_ASYNC_CACHEALWAYS{}; + for (size_t i = local_tidx; i <= num_tensors; i += kWarpSize) { + async_op.copy(offsets[i], shared_storage.smem_offsets[i]); + } + if (elect_one_sync()) printf("[DBG] block=%d DMA warp: offsets loop done\n", blockIdx.x); + for (size_t i = local_tidx; i < num_tensors; i += kWarpSize) { + async_op.copy(first_dims[i], shared_storage.smem_first_dims[i]); + } + if (elect_one_sync()) printf("[DBG] block=%d DMA warp: both loops done\n", blockIdx.x); + if constexpr (kEnableRHTColQuant) { + // Wait for all cp.async offsets/first_dims copies to complete, then provide the + // 1st of 2 required arrivals at tma_barrier[0] (with release semantics so consumers + // see the smem writes). The 2nd arrival comes from TMA B hardware completion below. + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_all;\n" ::); + __threadfence_block(); + if (elect_one_sync()) { + transformer_engine::ptx::mbarrier_arrive(&shared_storage.tma_barrier[0]); + printf("[DBG] block=%d DMA warp: tma_barrier arrived (1/2)\n", blockIdx.x); + } + } else { + // No TMA B in this path. Block until all cp.async ops issued above are complete, then + // signal cpasync_barrier[0] so the row quant warp can safely read offsets_smem. 
+ if (elect_one_sync()) printf("[DBG] block=%d DMA warp: before cp.async.wait_all\n", blockIdx.x); + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_all;\n" ::); + __threadfence_block(); + if (elect_one_sync()) { + printf("[DBG] block=%d DMA warp: signaling cpasync_barrier\n", blockIdx.x); + transformer_engine::ptx::mbarrier_arrive(&shared_storage.cpasync_barrier[0]); + printf("[DBG] block=%d DMA warp: cpasync_barrier signaled\n", blockIdx.x); + } + } + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); @@ -551,6 +615,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g kTmaRhtTensorTransactionBytes); copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), tBsB(_, 0)); + printf("[DBG] block=%d DMA warp: TMA B issued (2nd arrival pending)\n", blockIdx.x); } } @@ -704,14 +769,16 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; } + // tma_barrier[0] completes after both B tensor TMA and offsets cp.async finish. 
+ cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); // TODO(zhongbo): double check the logic here int group_idx = get_current_tensor_id( shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, - packed_N, M, offsets); + packed_N, M, offsets_smem); // Determine quantization scale factor layouts/output splits for this group TSFDLayout sfd_layout; - int cur_N = static_cast(first_dims[group_idx]); + int cur_N = static_cast(first_dims_smem[group_idx]); if constexpr (kEnableSwizzleSFOutput) { sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); } else { @@ -772,7 +839,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g // TODO(zhongbo): double check the logic here int cur_group_idx = get_current_tensor_id( - shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets); + shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets_smem); if (cur_group_idx != group_idx) { group_idx = cur_group_idx; @@ -786,7 +853,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g global_decode_scale = 1.0f / global_encode_scale; global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; // TODO(zhongbo): double check the logic here - cur_N = first_dims[group_idx]; + cur_N = first_dims_smem[group_idx]; if constexpr (kEnableSwizzleSFOutput) { sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); @@ -938,6 +1005,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g } else if (is_epilogue_row_quant_warp) { // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. 
cutlass::arch::warpgroup_reg_alloc<136>(); + if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: warp entered\n", blockIdx.x); if constexpr (kEnableRowQuant) { using S2RVectorType = uint128_t; @@ -997,11 +1065,22 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 // in order to go over the reserved named barrier count. constexpr int row_quant_barrier_id = 2; + if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: before NamedBarrier::sync\n", blockIdx.x); cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: after NamedBarrier::sync, before wait_barrier\n", blockIdx.x); + // Wait until offsets/first_dims cp.async copies are visible in smem. + // When kEnableRHTColQuant=true, tma_barrier[0] covers both B TMA and cp.async. + // When kEnableRHTColQuant=false, cpasync_barrier[0] covers cp.async only. 
+ if constexpr (kEnableRHTColQuant) { + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*phase_bit*/); + } else { + cute::wait_barrier(shared_storage.cpasync_barrier[0], 0 /*phase_bit*/); + } + if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: after wait_barrier, group_idx computation\n", blockIdx.x); int group_idx = get_current_tensor_id( shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, - packed_N, M, offsets); + packed_N, M, offsets_smem); float a_global_amax_val = shared_storage.global_a_amax[group_idx]; // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} static constexpr float fp4_max = 6.0f; @@ -1023,7 +1102,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); int cur_group_idx = get_current_tensor_id( - shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets); + shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets_smem); if (cur_group_idx != group_idx) { group_idx = cur_group_idx; a_global_amax_val = shared_storage.global_a_amax[group_idx]; @@ -1259,10 +1338,15 @@ void group_row_col_rht_gemm_ntt_w_sfc_graph_safe( static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes static int constexpr kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + // smem_offsets: (kMaxTensorsPerKernel+1) x int64, smem_first_dims: kMaxTensorsPerKernel x int64 + // Note: tma_barrier[1] and cpasync_barrier[1] (8 bytes each) are not counted here; their impact + // on kMaxStages is negligible (<0.01% of the 232KB budget). 
+ static int constexpr kSmemOffsetCacheBytes = + sizeof(int64_t) * (kMaxTensorsPerKernel + 1) + sizeof(int64_t) * kMaxTensorsPerKernel; static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + SchedulerPipelineBytes + TmemBasePtrsBytes + TmemDeallocBytes + BTensorBytes + - AccPipelineBytes; // Reserve for barriers and other uses + AccPipelineBytes + kSmemOffsetCacheBytes; // Reserve for barriers and other uses static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; auto sP = Int{}; // SMEM pipelines diff --git a/transformer_engine/common/hadamard_transform/test_offset_caching.cu b/transformer_engine/common/hadamard_transform/test_offset_caching.cu new file mode 100644 index 0000000000..fbf9bf51bb --- /dev/null +++ b/transformer_engine/common/hadamard_transform/test_offset_caching.cu @@ -0,0 +1,313 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +// Tests that caching offsets/first_dims from gmem into smem via cp.async +// produces identical results to reading directly from gmem in +// get_current_tensor_id(). 
+ +#include +#include + +#include +#include + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t _e = (expr); \ + ASSERT_EQ(_e, cudaSuccess) << cudaGetErrorString(_e); \ + } while (0) + +// --------------------------------------------------------------------------- +// Device-side implementation (mirrors the file under test) +// --------------------------------------------------------------------------- + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +// Exact copy of get_current_tensor_id from the file under test +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t current_offset, const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + size_t low = 0, hi = num_tensors; + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + if (mid_offset <= current_offset) low = mid + 1; + else hi = mid; + } + return (low == 0) ? 0 : (low - 1); + } +} + +// --------------------------------------------------------------------------- +// Test kernel +// +// For every query offset in `queries[]`: +// result_gmem[i] = get_current_tensor_id(..., gmem offsets pointer) +// result_smem[i] = get_current_tensor_id(..., smem-cached offsets pointer) +// +// The smem path uses the same cp.async + wait_all + __syncthreads() sequence +// as the production kernel. 
+// --------------------------------------------------------------------------- + +constexpr int kMaxTensors = 64; + +struct SmemStorage { + int64_t offsets[kMaxTensors + 1]; +}; + +__global__ void test_offset_caching_kernel( + const int64_t *gmem_offsets, // [num_tensors+1] + const size_t *queries, // [num_queries] — offsets to look up + size_t *result_gmem, // [num_queries] + size_t *result_smem, // [num_queries] + size_t num_tensors, + size_t num_queries, + ShapeRepresentation shape_rep, + size_t first_logical_dim, + size_t last_logical_dim) { + + extern __shared__ char smem_raw[]; + SmemStorage &smem = *reinterpret_cast(smem_raw); + + // ---- cooperative cp.async load (mirrors production kernel lines 385-397) ---- + for (size_t i = threadIdx.x; i <= num_tensors; i += blockDim.x) { + uint32_t dst = static_cast(__cvta_generic_to_shared(&smem.offsets[i])); + asm volatile("cp.async.ca.shared.global [%0], [%1], 8;\n" + :: "r"(dst), + "l"(reinterpret_cast(&gmem_offsets[i]))); + } + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_all;\n" ::); + __syncthreads(); + + const int64_t *const offsets_smem = smem.offsets; + + // ---- each thread handles one query ---------------------------------------- + for (size_t q = threadIdx.x; q < num_queries; q += blockDim.x) { + size_t offset = queries[q]; + + result_gmem[q] = get_current_tensor_id( + shape_rep, num_tensors, offset, + first_logical_dim, last_logical_dim, gmem_offsets); + + result_smem[q] = get_current_tensor_id( + shape_rep, num_tensors, offset, + first_logical_dim, last_logical_dim, offsets_smem); + } +} + +// --------------------------------------------------------------------------- +// Host-side reference — pure C++, no CUDA +// --------------------------------------------------------------------------- + +static size_t ref_get_tensor_id(ShapeRepresentation shape_rep, + size_t num_tensors, + size_t current_offset, + size_t first_logical_dim, + size_t last_logical_dim, + const 
std::vector &offsets) { + if (shape_rep == SAME_BOTH_DIMS) { + size_t current_row = current_offset / last_logical_dim; + size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + size_t low = 0, hi = num_tensors; + while (low < hi) { + size_t mid = low + (hi - low) / 2; + if (static_cast(offsets[mid]) <= current_offset) low = mid + 1; + else hi = mid; + } + return (low == 0) ? 0 : (low - 1); + } +} + +// --------------------------------------------------------------------------- +// Helper: run kernel + compare +// --------------------------------------------------------------------------- + +static void run_test(ShapeRepresentation shape_rep, + const std::vector &offsets_host, + const std::vector &queries_host, + size_t first_logical_dim, + size_t last_logical_dim) { + + const size_t num_tensors = offsets_host.size() - 1; // offsets has num_tensors+1 entries + const size_t num_queries = queries_host.size(); + + // --- allocate device memory ----------------------------------------------- + int64_t *d_offsets = nullptr; + size_t *d_queries = nullptr, *d_result_gmem = nullptr, *d_result_smem = nullptr; + + CUDA_CHECK(cudaMalloc(&d_offsets, (num_tensors + 1) * sizeof(int64_t))); + CUDA_CHECK(cudaMalloc(&d_queries, num_queries * sizeof(size_t))); + CUDA_CHECK(cudaMalloc(&d_result_gmem, num_queries * sizeof(size_t))); + CUDA_CHECK(cudaMalloc(&d_result_smem, num_queries * sizeof(size_t))); + + CUDA_CHECK(cudaMemcpy(d_offsets, offsets_host.data(), + (num_tensors + 1) * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_queries, queries_host.data(), + num_queries * sizeof(size_t), cudaMemcpyHostToDevice)); + + // --- launch --------------------------------------------------------------- + int smem_bytes = sizeof(SmemStorage); + test_offset_caching_kernel<<<1, 128, smem_bytes>>>( + d_offsets, d_queries, + d_result_gmem, d_result_smem, + num_tensors, num_queries, + shape_rep, first_logical_dim, 
last_logical_dim); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + // --- copy results back ---------------------------------------------------- + std::vector h_gmem(num_queries), h_smem(num_queries); + CUDA_CHECK(cudaMemcpy(h_gmem.data(), d_result_gmem, num_queries * sizeof(size_t), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_smem.data(), d_result_smem, num_queries * sizeof(size_t), + cudaMemcpyDeviceToHost)); + + // --- verify: gmem == smem == host_ref ------------------------------------- + for (size_t q = 0; q < num_queries; ++q) { + size_t ref = ref_get_tensor_id(shape_rep, num_tensors, queries_host[q], + first_logical_dim, last_logical_dim, offsets_host); + EXPECT_EQ(h_gmem[q], ref) + << "query=" << queries_host[q] << " gmem result mismatch at q=" << q; + EXPECT_EQ(h_smem[q], ref) + << "query=" << queries_host[q] << " smem result mismatch at q=" << q; + EXPECT_EQ(h_gmem[q], h_smem[q]) + << "query=" << queries_host[q] << " gmem != smem at q=" << q; + } + + cudaFree(d_offsets); + cudaFree(d_queries); + cudaFree(d_result_gmem); + cudaFree(d_result_smem); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +// Case 1: SAME_BOTH_DIMS — 4 tensors, equal rows each +TEST(OffsetCaching, SameBothDims) { + // 4 tensors, each with 128 rows, hidden=256 + // offsets (in elements): 0, 128*256, 2*128*256, 3*128*256, 4*128*256 + const size_t hidden = 256; + const size_t rows_per = 128; + const size_t N = 4; + + std::vector offsets(N + 1); + for (size_t i = 0; i <= N; ++i) + offsets[i] = static_cast(i * rows_per * hidden); + + // Query every row's first element + std::vector queries; + for (size_t t = 0; t < N; ++t) + for (size_t r = 0; r < rows_per; ++r) + queries.push_back(static_cast(offsets[t]) + r * hidden); + + run_test(SAME_BOTH_DIMS, offsets, queries, + /*first_logical_dim=*/N * rows_per, + 
/*last_logical_dim=*/hidden); +} + +// Case 2: VARYING_FIRST_DIM — tensors with different numbers of rows +TEST(OffsetCaching, VaryingFirstDim) { + // 3 tensors with row counts 128, 256, 192; hidden=512 + const size_t hidden = 512; + const std::vector row_counts = {128, 256, 192}; + const size_t N = row_counts.size(); + + std::vector offsets(N + 1); + offsets[0] = 0; + for (size_t i = 0; i < N; ++i) + offsets[i + 1] = offsets[i] + static_cast(row_counts[i] * hidden); + + // Query: first element, last element, and middle element of each tensor + std::vector queries; + for (size_t t = 0; t < N; ++t) { + size_t start = static_cast(offsets[t]); + size_t end = static_cast(offsets[t + 1]) - 1; + size_t mid = (start + end) / 2; + queries.push_back(start); + queries.push_back(mid); + queries.push_back(end); + } + + run_test(VARYING_FIRST_DIM, offsets, queries, + /*first_logical_dim=*/0, // unused for binary search path + /*last_logical_dim=*/hidden); +} + +// Case 3: boundary — query lands exactly on an offset boundary +TEST(OffsetCaching, ExactBoundary) { + const size_t hidden = 128; + const std::vector row_counts = {128, 128, 256, 64}; + const size_t N = row_counts.size(); + + std::vector offsets(N + 1); + offsets[0] = 0; + for (size_t i = 0; i < N; ++i) + offsets[i + 1] = offsets[i] + static_cast(row_counts[i] * hidden); + + // Query exactly at each offset boundary (should map to the tensor that starts there) + std::vector queries; + for (size_t t = 0; t < N; ++t) + queries.push_back(static_cast(offsets[t])); + + run_test(VARYING_FIRST_DIM, offsets, queries, + /*first_logical_dim=*/0, + /*last_logical_dim=*/hidden); +} + +// Case 4: single tensor — degenerate case +TEST(OffsetCaching, SingleTensor) { + const size_t hidden = 256; + const size_t rows = 512; + + std::vector offsets = {0, static_cast(rows * hidden)}; + std::vector queries = {0, rows * hidden / 2, rows * hidden - 1}; + + run_test(VARYING_FIRST_DIM, offsets, queries, + /*first_logical_dim=*/rows, + 
/*last_logical_dim=*/hidden); +} + +// Case 5: maximum tensors (kMaxTensors=64) +TEST(OffsetCaching, MaxTensors) { + const size_t hidden = 128; + const size_t rows_each = 128; + const size_t N = 64; + + std::vector offsets(N + 1); + offsets[0] = 0; + for (size_t i = 0; i < N; ++i) + offsets[i + 1] = offsets[i] + static_cast(rows_each * hidden); + + // One query per tensor + std::vector queries; + for (size_t t = 0; t < N; ++t) + queries.push_back(static_cast(offsets[t])); + + run_test(VARYING_FIRST_DIM, offsets, queries, + /*first_logical_dim=*/0, + /*last_logical_dim=*/hidden); +} From 5292eb749c3243359bf0e6c009a6b43459482e54 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Wed, 15 Apr 2026 09:13:07 -0700 Subject: [PATCH 2/6] Add nvfp4 benchmark and test files from NVFP4_graph_safe branch - bench_search.py, bench_structural.py, bench_sweep_swizzle.py, ncu_test.py Signed-off-by: Siddhartha Raman S --- tests/pytorch/nvfp4/bench_search.py | 87 +++++++++++++++++++++ tests/pytorch/nvfp4/bench_structural.py | 52 +++++++++++++ tests/pytorch/nvfp4/bench_sweep_swizzle.py | 90 ++++++++++++++++++++++ tests/pytorch/nvfp4/ncu_test.py | 19 +++++ 4 files changed, 248 insertions(+) create mode 100644 tests/pytorch/nvfp4/bench_search.py create mode 100644 tests/pytorch/nvfp4/bench_structural.py create mode 100644 tests/pytorch/nvfp4/bench_sweep_swizzle.py create mode 100644 tests/pytorch/nvfp4/ncu_test.py diff --git a/tests/pytorch/nvfp4/bench_search.py b/tests/pytorch/nvfp4/bench_search.py new file mode 100644 index 0000000000..f301f852f9 --- /dev/null +++ b/tests/pytorch/nvfp4/bench_search.py @@ -0,0 +1,87 @@ +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +import torch +import torch.cuda.nvtx as nvtx + +N = 7168 +num_experts = 64 +ITERS = 50 + +M_VALUES = [8192, 16384, 32768, 65536, 131072] + + +def make_unequal_splits(M, num_experts): + base = M // num_experts + splits = [] + for i 
in range(num_experts): + if i % 2 == 0: + splits.append(base - 128) + else: + splits.append(base + 128) + # fix rounding so sum == M + diff = M - sum(splits) + splits[-1] += diff + return splits + + +def bench(fn, label, iters=ITERS): + for _ in range(10): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + nvtx.range_push(label) + start.record() + for _ in range(iters): + fn() + end.record() + nvtx.range_pop() + torch.cuda.synchronize() + us = start.elapsed_time(end) / iters * 1000 + print(f" {label}: {us:.1f} us") + return us + + +print(f"N={N}, num_experts={num_experts}") +print("-" * 60) + +for M in M_VALUES: + if M % num_experts != 0 or (M // num_experts) <= 128: + print(f"M={M}: skipped") + continue + + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + label_prefix = f"M{M}" + + print(f"\nM={M}:") + + # --- graph-safe, equal splits (O(1) division) --- + equal_splits = [M // num_experts] * num_experts + equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda") + q_eq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + q_eq.optimize_for_gemm = False + bench( + lambda: tex.group_quantize(x, q_eq, num_experts, equal_tensor), + f"{label_prefix}_graph_safe_equal_O1", + ) + + # --- graph-safe, unequal splits (binary search) --- + unequal_splits = make_unequal_splits(M, num_experts) + unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda") + q_uneq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + q_uneq.optimize_for_gemm = False + bench( + lambda: tex.group_quantize(x, q_uneq, num_experts, unequal_tensor), + f"{label_prefix}_graph_safe_unequal_bsearch", + ) + + # --- non-graph-safe (linear scan) --- + q_list = [ + NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + for _ in range(num_experts) + ] + bench( + lambda: 
tex.split_quantize(x, equal_splits, q_list), + f"{label_prefix}_non_graph_safe_linear", + ) diff --git a/tests/pytorch/nvfp4/bench_structural.py b/tests/pytorch/nvfp4/bench_structural.py new file mode 100644 index 0000000000..3eaa704e6c --- /dev/null +++ b/tests/pytorch/nvfp4/bench_structural.py @@ -0,0 +1,52 @@ +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +import torch +import torch.cuda.nvtx as nvtx + +N = 7168 +num_experts = 64 + +def make_quantizer(): + q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + q.optimize_for_gemm = True + return q + +def bench(fn, label, iters=100): + for _ in range(10): fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + nvtx.range_push(label) + start.record() + for _ in range(iters): fn() + end.record() + nvtx.range_pop() + torch.cuda.synchronize() + print(f"{label}: {start.elapsed_time(end) / iters * 1000:.1f} us") + +for M in [16384, 65536, 131072]: + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + + # 1. graph-safe + equal splits -> O(1) division (SAME_BOTH_DIMS) + equal_splits = [M // num_experts] * num_experts + equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda") + q1 = make_quantizer() + bench(lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), + f"[M={M}] graph_safe_equal_O1") + + # 2. graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM) + base = M // num_experts + unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)] + unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda") + q2 = make_quantizer() + bench(lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor), + f"[M={M}] graph_safe_unequal_binary_search") + + # 3. 
non-graph-safe + linear scan (GetGroupIdx) + q_list = [NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + for _ in range(num_experts)] + bench(lambda: tex.split_quantize(x, equal_splits, q_list), + f"[M={M}] non_graph_safe_linear_scan") + + print() diff --git a/tests/pytorch/nvfp4/bench_sweep_swizzle.py b/tests/pytorch/nvfp4/bench_sweep_swizzle.py new file mode 100644 index 0000000000..6dd8e351c3 --- /dev/null +++ b/tests/pytorch/nvfp4/bench_sweep_swizzle.py @@ -0,0 +1,90 @@ +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +import torch +import torch.cuda.nvtx as nvtx + +N = 7168 +num_experts = 64 +ITERS = 50 + +M_VALUES = [8192, 16384, 32768, 65536, 131072] + + +def bench(fn, label, iters=ITERS): + # warmup + for _ in range(10): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + nvtx.range_push(label) + start.record() + for _ in range(iters): + fn() + end.record() + nvtx.range_pop() + torch.cuda.synchronize() + us = start.elapsed_time(end) / iters * 1000 + print(f" {label}: {us:.1f} us") + return us + + +print(f"N={N}, num_experts={num_experts}") +print("-" * 60) + +for M in M_VALUES: + if M % num_experts != 0: + print(f"M={M}: skipped (not divisible by num_experts={num_experts})") + continue + + rows_per_expert = M // num_experts + split_sections = [rows_per_expert] * num_experts + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + + print(f"\nM={M} ({rows_per_expert} rows/expert):") + + label_prefix = f"M{M}" + + # --- graph-safe, swizzle ON --- + q_on = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + ) + q_on.optimize_for_gemm = True + bench( + lambda: tex.group_quantize(x, q_on, num_experts, split_section_tensor), + 
f"{label_prefix}_graph_safe_swizzle_ON", + ) + + # --- graph-safe, swizzle OFF --- + q_off = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + ) + q_off.optimize_for_gemm = False + bench( + lambda: tex.group_quantize(x, q_off, num_experts, split_section_tensor), + f"{label_prefix}_graph_safe_swizzle_OFF", + ) + + # --- non-graph-safe --- + q_list = [ + NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + ) + for _ in range(num_experts) + ] + bench( + lambda: tex.split_quantize(x, split_sections, q_list), + f"{label_prefix}_non_graph_safe", + ) + diff --git a/tests/pytorch/nvfp4/ncu_test.py b/tests/pytorch/nvfp4/ncu_test.py new file mode 100644 index 0000000000..64319166db --- /dev/null +++ b/tests/pytorch/nvfp4/ncu_test.py @@ -0,0 +1,19 @@ +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +import torch + +M, N, num_experts = 16384, 7168, 64 +x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") +splits = [M // num_experts] * num_experts +split_tensor = torch.tensor(splits, dtype=torch.int64, device="cuda") + +# warmup +q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) +for _ in range(3): + tex.group_quantize(x, q, num_experts, split_tensor) +torch.cuda.synchronize() + +# single measured launch +tex.group_quantize(x, q, num_experts, split_tensor) +torch.cuda.synchronize() From ac18947e545a8b86224bad9d3e309d212cb6d6be Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Wed, 15 Apr 2026 09:13:07 -0700 Subject: [PATCH 3/6] Remove debug printf statements from graph-safe hadamard kernel Signed-off-by: Siddhartha Raman S --- ..._row_cast_col_hadamard_transform_cast_fusion.cu | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu 
b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 568c876cfa..349370ce6c 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -521,13 +521,11 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g } } __syncthreads(); - if (threadIdx.x == 0) printf("[DBG] block=%d after barrier init\n", blockIdx.x); // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer if (is_dma_warp) { // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). cutlass::arch::warpgroup_reg_dealloc<32>(); - if (elect_one_sync()) printf("[DBG] block=%d DMA warp: entered\n", blockIdx.x); // Use LDGSTS (cp.async) in the DMA warp to load offsets and first_dims into shared memory. // For kEnableRHTColQuant=true: after cp.async completes, the DMA warp provides the 1st of @@ -537,16 +535,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g // For kEnableRHTColQuant=false: cpasync_barrier[0] is used instead. 
constexpr int kWarpSize = 32; const int local_tidx = threadIdx.x % kWarpSize; - if (elect_one_sync()) printf("[DBG] block=%d DMA warp: before offsets loop, num_tensors=%zu local_tidx=%d\n", blockIdx.x, num_tensors, (int)local_tidx); auto async_op = cute::SM80_CP_ASYNC_CACHEALWAYS{}; for (size_t i = local_tidx; i <= num_tensors; i += kWarpSize) { async_op.copy(offsets[i], shared_storage.smem_offsets[i]); } - if (elect_one_sync()) printf("[DBG] block=%d DMA warp: offsets loop done\n", blockIdx.x); for (size_t i = local_tidx; i < num_tensors; i += kWarpSize) { async_op.copy(first_dims[i], shared_storage.smem_first_dims[i]); } - if (elect_one_sync()) printf("[DBG] block=%d DMA warp: both loops done\n", blockIdx.x); if constexpr (kEnableRHTColQuant) { // Wait for all cp.async offsets/first_dims copies to complete, then provide the // 1st of 2 required arrivals at tma_barrier[0] (with release semantics so consumers @@ -556,19 +551,15 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g __threadfence_block(); if (elect_one_sync()) { transformer_engine::ptx::mbarrier_arrive(&shared_storage.tma_barrier[0]); - printf("[DBG] block=%d DMA warp: tma_barrier arrived (1/2)\n", blockIdx.x); } } else { // No TMA B in this path. Block until all cp.async ops issued above are complete, then // signal cpasync_barrier[0] so the row quant warp can safely read offsets_smem. 
- if (elect_one_sync()) printf("[DBG] block=%d DMA warp: before cp.async.wait_all\n", blockIdx.x); asm volatile("cp.async.commit_group;\n" ::); asm volatile("cp.async.wait_all;\n" ::); __threadfence_block(); if (elect_one_sync()) { - printf("[DBG] block=%d DMA warp: signaling cpasync_barrier\n", blockIdx.x); transformer_engine::ptx::mbarrier_arrive(&shared_storage.cpasync_barrier[0]); - printf("[DBG] block=%d DMA warp: cpasync_barrier signaled\n", blockIdx.x); } } @@ -615,7 +606,6 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g kTmaRhtTensorTransactionBytes); copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), tBsB(_, 0)); - printf("[DBG] block=%d DMA warp: TMA B issued (2nd arrival pending)\n", blockIdx.x); } } @@ -1005,7 +995,6 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g } else if (is_epilogue_row_quant_warp) { // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. cutlass::arch::warpgroup_reg_alloc<136>(); - if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: warp entered\n", blockIdx.x); if constexpr (kEnableRowQuant) { using S2RVectorType = uint128_t; @@ -1065,9 +1054,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 // in order to go over the reserved named barrier count. constexpr int row_quant_barrier_id = 2; - if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: before NamedBarrier::sync\n", blockIdx.x); cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); - if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: after NamedBarrier::sync, before wait_barrier\n", blockIdx.x); // Wait until offsets/first_dims cp.async copies are visible in smem. // When kEnableRHTColQuant=true, tma_barrier[0] covers both B TMA and cp.async. 
@@ -1077,7 +1064,6 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g } else { cute::wait_barrier(shared_storage.cpasync_barrier[0], 0 /*phase_bit*/); } - if (threadIdx.x == 256) printf("[DBG] block=%d row_quant: after wait_barrier, group_idx computation\n", blockIdx.x); int group_idx = get_current_tensor_id( shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, packed_N, M, offsets_smem); From 3d255ab4155d7ce7cc3b2439724cbac8010f2fe6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 01:12:55 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Siddhartha Raman S --- .../pytorch/nvfp4/bench_graph_safe_swizzle.py | 44 +++--- tests/pytorch/nvfp4/bench_search.py | 110 ++++++------- tests/pytorch/nvfp4/bench_structural.py | 95 ++++++----- tests/pytorch/nvfp4/bench_sweep_swizzle.py | 63 ++++---- tests/pytorch/nvfp4/ncu_test.py | 32 ++-- ...cast_col_hadamard_transform_cast_fusion.cu | 8 +- .../hadamard_transform/test_offset_caching.cu | 148 ++++++++---------- 7 files changed, 246 insertions(+), 254 deletions(-) diff --git a/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py index 29e30b7727..0672afea24 100644 --- a/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py +++ b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py @@ -1,26 +1,26 @@ -import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch import NVFP4Quantizer - +import torch +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + M, N = 8192, 7168 # your actual shape -x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") -split_sections = torch.tensor([128] * (M // 128), dtype=torch.int64, device="cuda") - -for optimize_for_gemm in [False, True]: - q = 
NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) - q.optimize_for_gemm = optimize_for_gemm - - # warmup - for _ in range(10): - tex.group_quantize(x, q, split_sections.shape[0], split_sections) - torch.cuda.synchronize() - +x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") +split_sections = torch.tensor([128] * (M // 128), dtype=torch.int64, device="cuda") + +for optimize_for_gemm in [False, True]: + q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + q.optimize_for_gemm = optimize_for_gemm + + # warmup + for _ in range(10): + tex.group_quantize(x, q, split_sections.shape[0], split_sections) + torch.cuda.synchronize() + # time - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(100): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(100): tex.group_quantize(x, q, split_sections.shape[0], split_sections) - end.record() - torch.cuda.synchronize() + end.record() + torch.cuda.synchronize() print(f"optimize_for_gemm={optimize_for_gemm}: {start.elapsed_time(end) / 100 * 1000:.1f} μs") diff --git a/tests/pytorch/nvfp4/bench_search.py b/tests/pytorch/nvfp4/bench_search.py index f301f852f9..7f012c5951 100644 --- a/tests/pytorch/nvfp4/bench_search.py +++ b/tests/pytorch/nvfp4/bench_search.py @@ -1,87 +1,87 @@ -import transformer_engine.pytorch as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex -from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch import NVFP4Quantizer import torch -import torch.cuda.nvtx as nvtx - -N = 7168 +import torch.cuda.nvtx as nvtx + +N = 7168 num_experts = 64 -ITERS = 50 - -M_VALUES = [8192, 16384, 32768, 65536, 131072] - - +ITERS = 50 + +M_VALUES = [8192, 16384, 32768, 65536, 131072] + + def make_unequal_splits(M, num_experts): base = M // 
num_experts - splits = [] + splits = [] for i in range(num_experts): - if i % 2 == 0: + if i % 2 == 0: splits.append(base - 128) else: - splits.append(base + 128) + splits.append(base + 128) # fix rounding so sum == M - diff = M - sum(splits) - splits[-1] += diff + diff = M - sum(splits) + splits[-1] += diff return splits - - + + def bench(fn, label, iters=ITERS): - for _ in range(10): - fn() + for _ in range(10): + fn() torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) nvtx.range_push(label) - start.record() + start.record() for _ in range(iters): - fn() + fn() end.record() - nvtx.range_pop() + nvtx.range_pop() torch.cuda.synchronize() us = start.elapsed_time(end) / iters * 1000 - print(f" {label}: {us:.1f} us") + print(f" {label}: {us:.1f} us") return us - - + + print(f"N={N}, num_experts={num_experts}") -print("-" * 60) - +print("-" * 60) + for M in M_VALUES: if M % num_experts != 0 or (M // num_experts) <= 128: - print(f"M={M}: skipped") - continue - - x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + print(f"M={M}: skipped") + continue + + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") label_prefix = f"M{M}" - - print(f"\nM={M}:") - - # --- graph-safe, equal splits (O(1) division) --- + + print(f"\nM={M}:") + + # --- graph-safe, equal splits (O(1) division) --- equal_splits = [M // num_experts] * num_experts - equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda") + equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda") q_eq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) - q_eq.optimize_for_gemm = False - bench( - lambda: tex.group_quantize(x, q_eq, num_experts, equal_tensor), - f"{label_prefix}_graph_safe_equal_O1", - ) - - # --- graph-safe, unequal splits (binary search) --- + q_eq.optimize_for_gemm = False + bench( + lambda: 
tex.group_quantize(x, q_eq, num_experts, equal_tensor), + f"{label_prefix}_graph_safe_equal_O1", + ) + + # --- graph-safe, unequal splits (binary search) --- unequal_splits = make_unequal_splits(M, num_experts) - unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda") + unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda") q_uneq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) - q_uneq.optimize_for_gemm = False - bench( - lambda: tex.group_quantize(x, q_uneq, num_experts, unequal_tensor), - f"{label_prefix}_graph_safe_unequal_bsearch", - ) + q_uneq.optimize_for_gemm = False + bench( + lambda: tex.group_quantize(x, q_uneq, num_experts, unequal_tensor), + f"{label_prefix}_graph_safe_unequal_bsearch", + ) - # --- non-graph-safe (linear scan) --- + # --- non-graph-safe (linear scan) --- q_list = [ - NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) - for _ in range(num_experts) + NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + for _ in range(num_experts) ] - bench( + bench( lambda: tex.split_quantize(x, equal_splits, q_list), f"{label_prefix}_non_graph_safe_linear", - ) + ) diff --git a/tests/pytorch/nvfp4/bench_structural.py b/tests/pytorch/nvfp4/bench_structural.py index 3eaa704e6c..9d20da3ffd 100644 --- a/tests/pytorch/nvfp4/bench_structural.py +++ b/tests/pytorch/nvfp4/bench_structural.py @@ -1,52 +1,63 @@ -import transformer_engine.pytorch as te -import transformer_engine_torch as tex -from transformer_engine.pytorch import NVFP4Quantizer -import torch -import torch.cuda.nvtx as nvtx - +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +import torch +import torch.cuda.nvtx as nvtx + N = 7168 -num_experts = 64 - +num_experts = 64 + + def make_quantizer(): - q = NVFP4Quantizer(rowwise=True, columnwise=True, 
with_rht=True, with_post_rht_amax=True) - q.optimize_for_gemm = True + q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + q.optimize_for_gemm = True return q - -def bench(fn, label, iters=100): - for _ in range(10): fn() - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) + + +def bench(fn, label, iters=100): + for _ in range(10): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - nvtx.range_push(label) + nvtx.range_push(label) start.record() - for _ in range(iters): fn() + for _ in range(iters): + fn() end.record() - nvtx.range_pop() + nvtx.range_pop() torch.cuda.synchronize() print(f"{label}: {start.elapsed_time(end) / iters * 1000:.1f} us") - -for M in [16384, 65536, 131072]: - x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") - + + +for M in [16384, 65536, 131072]: + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + # 1. graph-safe + equal splits -> O(1) division (SAME_BOTH_DIMS) - equal_splits = [M // num_experts] * num_experts - equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda") - q1 = make_quantizer() - bench(lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), - f"[M={M}] graph_safe_equal_O1") - - # 2. graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM) - base = M // num_experts - unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)] - unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda") + equal_splits = [M // num_experts] * num_experts + equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda") + q1 = make_quantizer() + bench( + lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), f"[M={M}] graph_safe_equal_O1" + ) + + # 2. 
graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM) + base = M // num_experts + unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)] + unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda") q2 = make_quantizer() - bench(lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor), - f"[M={M}] graph_safe_unequal_binary_search") - - # 3. non-graph-safe + linear scan (GetGroupIdx) - q_list = [NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) - for _ in range(num_experts)] - bench(lambda: tex.split_quantize(x, equal_splits, q_list), - f"[M={M}] non_graph_safe_linear_scan") - - print() + bench( + lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor), + f"[M={M}] graph_safe_unequal_binary_search", + ) + + # 3. non-graph-safe + linear scan (GetGroupIdx) + q_list = [ + NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) + for _ in range(num_experts) + ] + bench( + lambda: tex.split_quantize(x, equal_splits, q_list), f"[M={M}] non_graph_safe_linear_scan" + ) + + print() diff --git a/tests/pytorch/nvfp4/bench_sweep_swizzle.py b/tests/pytorch/nvfp4/bench_sweep_swizzle.py index 6dd8e351c3..9f86afee78 100644 --- a/tests/pytorch/nvfp4/bench_sweep_swizzle.py +++ b/tests/pytorch/nvfp4/bench_sweep_swizzle.py @@ -2,89 +2,88 @@ import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer import torch -import torch.cuda.nvtx as nvtx - -N = 7168 +import torch.cuda.nvtx as nvtx + +N = 7168 num_experts = 64 ITERS = 50 M_VALUES = [8192, 16384, 32768, 65536, 131072] - + def bench(fn, label, iters=ITERS): - # warmup + # warmup for _ in range(10): fn() torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - nvtx.range_push(label) + nvtx.range_push(label) start.record() - for _ in range(iters): - fn() + for _ in range(iters): + fn() end.record() 
nvtx.range_pop() torch.cuda.synchronize() us = start.elapsed_time(end) / iters * 1000 - print(f" {label}: {us:.1f} us") + print(f" {label}: {us:.1f} us") return us - - + + print(f"N={N}, num_experts={num_experts}") print("-" * 60) -for M in M_VALUES: +for M in M_VALUES: if M % num_experts != 0: - print(f"M={M}: skipped (not divisible by num_experts={num_experts})") + print(f"M={M}: skipped (not divisible by num_experts={num_experts})") continue - rows_per_expert = M // num_experts + rows_per_expert = M // num_experts split_sections = [rows_per_expert] * num_experts - split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") - + print(f"\nM={M} ({rows_per_expert} rows/expert):") - + label_prefix = f"M{M}" # --- graph-safe, swizzle ON --- q_on = NVFP4Quantizer( - rowwise=True, + rowwise=True, columnwise=True, - with_rht=True, + with_rht=True, with_post_rht_amax=True, ) q_on.optimize_for_gemm = True - bench( + bench( lambda: tex.group_quantize(x, q_on, num_experts, split_section_tensor), - f"{label_prefix}_graph_safe_swizzle_ON", - ) + f"{label_prefix}_graph_safe_swizzle_ON", + ) - # --- graph-safe, swizzle OFF --- + # --- graph-safe, swizzle OFF --- q_off = NVFP4Quantizer( - rowwise=True, + rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True, ) q_off.optimize_for_gemm = False - bench( + bench( lambda: tex.group_quantize(x, q_off, num_experts, split_section_tensor), - f"{label_prefix}_graph_safe_swizzle_OFF", - ) + f"{label_prefix}_graph_safe_swizzle_OFF", + ) - # --- non-graph-safe --- + # --- non-graph-safe --- q_list = [ - NVFP4Quantizer( + NVFP4Quantizer( rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True, ) for _ in range(num_experts) - ] + ] bench( - lambda: tex.split_quantize(x, split_sections, q_list), + lambda: tex.split_quantize(x, split_sections, 
q_list), f"{label_prefix}_non_graph_safe", ) - diff --git a/tests/pytorch/nvfp4/ncu_test.py b/tests/pytorch/nvfp4/ncu_test.py index 64319166db..bff6222937 100644 --- a/tests/pytorch/nvfp4/ncu_test.py +++ b/tests/pytorch/nvfp4/ncu_test.py @@ -1,19 +1,19 @@ -import transformer_engine.pytorch as te -import transformer_engine_torch as tex -from transformer_engine.pytorch import NVFP4Quantizer -import torch - -M, N, num_experts = 16384, 7168, 64 -x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") -splits = [M // num_experts] * num_experts +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +import torch + +M, N, num_experts = 16384, 7168, 64 +x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") +splits = [M // num_experts] * num_experts split_tensor = torch.tensor(splits, dtype=torch.int64, device="cuda") - -# warmup -q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) -for _ in range(3): + +# warmup +q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True) +for _ in range(3): tex.group_quantize(x, q, num_experts, split_tensor) -torch.cuda.synchronize() - -# single measured launch +torch.cuda.synchronize() + +# single measured launch tex.group_quantize(x, q, num_experts, split_tensor) -torch.cuda.synchronize() +torch.cuda.synchronize() diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 349370ce6c..7a23529645 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1329,10 +1329,10 @@ void group_row_col_rht_gemm_ntt_w_sfc_graph_safe( // on 
kMaxStages is negligible (<0.01% of the 232KB budget). static int constexpr kSmemOffsetCacheBytes = sizeof(int64_t) * (kMaxTensorsPerKernel + 1) + sizeof(int64_t) * kMaxTensorsPerKernel; - static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + - SchedulerPipelineBytes + TmemBasePtrsBytes + - TmemDeallocBytes + BTensorBytes + - AccPipelineBytes + kSmemOffsetCacheBytes; // Reserve for barriers and other uses + static int constexpr kReservedBytes = + SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + SchedulerPipelineBytes + + TmemBasePtrsBytes + TmemDeallocBytes + BTensorBytes + AccPipelineBytes + + kSmemOffsetCacheBytes; // Reserve for barriers and other uses static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; auto sP = Int{}; // SMEM pipelines diff --git a/transformer_engine/common/hadamard_transform/test_offset_caching.cu b/transformer_engine/common/hadamard_transform/test_offset_caching.cu index fbf9bf51bb..d6db4bb990 100644 --- a/transformer_engine/common/hadamard_transform/test_offset_caching.cu +++ b/transformer_engine/common/hadamard_transform/test_offset_caching.cu @@ -18,10 +18,10 @@ // Helpers // --------------------------------------------------------------------------- -#define CUDA_CHECK(expr) \ - do { \ - cudaError_t _e = (expr); \ - ASSERT_EQ(_e, cudaSuccess) << cudaGetErrorString(_e); \ +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t _e = (expr); \ + ASSERT_EQ(_e, cudaSuccess) << cudaGetErrorString(_e); \ } while (0) // --------------------------------------------------------------------------- @@ -29,29 +29,30 @@ // --------------------------------------------------------------------------- enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, + SAME_BOTH_DIMS = 0, VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, + VARYING_LAST_DIM = 2, VARYING_BOTH_DIMS = 3 }; // Exact copy of get_current_tensor_id from the file under test __device__ __forceinline__ size_t 
get_current_tensor_id( - const ShapeRepresentation shape_rep, const size_t num_tensors, - const size_t current_offset, const size_t first_logical_dim, - const size_t last_logical_dim, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr) { if (shape_rep == SAME_BOTH_DIMS) { - const size_t current_row = current_offset / last_logical_dim; + const size_t current_row = current_offset / last_logical_dim; const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; } else { size_t low = 0, hi = num_tensors; while (low < hi) { - const size_t mid = low + (hi - low) / 2; + const size_t mid = low + (hi - low) / 2; const size_t mid_offset = static_cast(offsets_ptr[mid]); - if (mid_offset <= current_offset) low = mid + 1; - else hi = mid; + if (mid_offset <= current_offset) + low = mid + 1; + else + hi = mid; } return (low == 0) ? 
0 : (low - 1); } @@ -75,28 +76,23 @@ struct SmemStorage { }; __global__ void test_offset_caching_kernel( - const int64_t *gmem_offsets, // [num_tensors+1] - const size_t *queries, // [num_queries] — offsets to look up - size_t *result_gmem, // [num_queries] - size_t *result_smem, // [num_queries] - size_t num_tensors, - size_t num_queries, - ShapeRepresentation shape_rep, - size_t first_logical_dim, - size_t last_logical_dim) { - + const int64_t *gmem_offsets, // [num_tensors+1] + const size_t *queries, // [num_queries] — offsets to look up + size_t *result_gmem, // [num_queries] + size_t *result_smem, // [num_queries] + size_t num_tensors, size_t num_queries, ShapeRepresentation shape_rep, size_t first_logical_dim, + size_t last_logical_dim) { extern __shared__ char smem_raw[]; SmemStorage &smem = *reinterpret_cast(smem_raw); // ---- cooperative cp.async load (mirrors production kernel lines 385-397) ---- for (size_t i = threadIdx.x; i <= num_tensors; i += blockDim.x) { uint32_t dst = static_cast(__cvta_generic_to_shared(&smem.offsets[i])); - asm volatile("cp.async.ca.shared.global [%0], [%1], 8;\n" - :: "r"(dst), - "l"(reinterpret_cast(&gmem_offsets[i]))); + asm volatile("cp.async.ca.shared.global [%0], [%1], 8;\n" ::"r"(dst), + "l"(reinterpret_cast(&gmem_offsets[i]))); } asm volatile("cp.async.commit_group;\n" ::); - asm volatile("cp.async.wait_all;\n" ::); + asm volatile("cp.async.wait_all;\n" ::); __syncthreads(); const int64_t *const offsets_smem = smem.offsets; @@ -105,13 +101,11 @@ __global__ void test_offset_caching_kernel( for (size_t q = threadIdx.x; q < num_queries; q += blockDim.x) { size_t offset = queries[q]; - result_gmem[q] = get_current_tensor_id( - shape_rep, num_tensors, offset, - first_logical_dim, last_logical_dim, gmem_offsets); + result_gmem[q] = get_current_tensor_id(shape_rep, num_tensors, offset, first_logical_dim, + last_logical_dim, gmem_offsets); - result_smem[q] = get_current_tensor_id( - shape_rep, num_tensors, offset, - 
first_logical_dim, last_logical_dim, offsets_smem); + result_smem[q] = get_current_tensor_id(shape_rep, num_tensors, offset, first_logical_dim, + last_logical_dim, offsets_smem); } } @@ -119,22 +113,21 @@ __global__ void test_offset_caching_kernel( // Host-side reference — pure C++, no CUDA // --------------------------------------------------------------------------- -static size_t ref_get_tensor_id(ShapeRepresentation shape_rep, - size_t num_tensors, - size_t current_offset, - size_t first_logical_dim, - size_t last_logical_dim, - const std::vector &offsets) { +static size_t ref_get_tensor_id(ShapeRepresentation shape_rep, size_t num_tensors, + size_t current_offset, size_t first_logical_dim, + size_t last_logical_dim, const std::vector &offsets) { if (shape_rep == SAME_BOTH_DIMS) { - size_t current_row = current_offset / last_logical_dim; + size_t current_row = current_offset / last_logical_dim; size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; } else { size_t low = 0, hi = num_tensors; while (low < hi) { size_t mid = low + (hi - low) / 2; - if (static_cast(offsets[mid]) <= current_offset) low = mid + 1; - else hi = mid; + if (static_cast(offsets[mid]) <= current_offset) + low = mid + 1; + else + hi = mid; } return (low == 0) ? 
0 : (low - 1); } @@ -144,36 +137,31 @@ static size_t ref_get_tensor_id(ShapeRepresentation shape_rep, // Helper: run kernel + compare // --------------------------------------------------------------------------- -static void run_test(ShapeRepresentation shape_rep, - const std::vector &offsets_host, - const std::vector &queries_host, - size_t first_logical_dim, +static void run_test(ShapeRepresentation shape_rep, const std::vector &offsets_host, + const std::vector &queries_host, size_t first_logical_dim, size_t last_logical_dim) { - const size_t num_tensors = offsets_host.size() - 1; // offsets has num_tensors+1 entries const size_t num_queries = queries_host.size(); // --- allocate device memory ----------------------------------------------- int64_t *d_offsets = nullptr; - size_t *d_queries = nullptr, *d_result_gmem = nullptr, *d_result_smem = nullptr; + size_t *d_queries = nullptr, *d_result_gmem = nullptr, *d_result_smem = nullptr; - CUDA_CHECK(cudaMalloc(&d_offsets, (num_tensors + 1) * sizeof(int64_t))); - CUDA_CHECK(cudaMalloc(&d_queries, num_queries * sizeof(size_t))); - CUDA_CHECK(cudaMalloc(&d_result_gmem, num_queries * sizeof(size_t))); - CUDA_CHECK(cudaMalloc(&d_result_smem, num_queries * sizeof(size_t))); + CUDA_CHECK(cudaMalloc(&d_offsets, (num_tensors + 1) * sizeof(int64_t))); + CUDA_CHECK(cudaMalloc(&d_queries, num_queries * sizeof(size_t))); + CUDA_CHECK(cudaMalloc(&d_result_gmem, num_queries * sizeof(size_t))); + CUDA_CHECK(cudaMalloc(&d_result_smem, num_queries * sizeof(size_t))); - CUDA_CHECK(cudaMemcpy(d_offsets, offsets_host.data(), - (num_tensors + 1) * sizeof(int64_t), cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(d_queries, queries_host.data(), - num_queries * sizeof(size_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_offsets, offsets_host.data(), (num_tensors + 1) * sizeof(int64_t), + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_queries, queries_host.data(), num_queries * sizeof(size_t), + cudaMemcpyHostToDevice)); // 
--- launch --------------------------------------------------------------- int smem_bytes = sizeof(SmemStorage); test_offset_caching_kernel<<<1, 128, smem_bytes>>>( - d_offsets, d_queries, - d_result_gmem, d_result_smem, - num_tensors, num_queries, - shape_rep, first_logical_dim, last_logical_dim); + d_offsets, d_queries, d_result_gmem, d_result_smem, num_tensors, num_queries, shape_rep, + first_logical_dim, last_logical_dim); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); @@ -186,14 +174,11 @@ static void run_test(ShapeRepresentation shape_rep, // --- verify: gmem == smem == host_ref ------------------------------------- for (size_t q = 0; q < num_queries; ++q) { - size_t ref = ref_get_tensor_id(shape_rep, num_tensors, queries_host[q], - first_logical_dim, last_logical_dim, offsets_host); - EXPECT_EQ(h_gmem[q], ref) - << "query=" << queries_host[q] << " gmem result mismatch at q=" << q; - EXPECT_EQ(h_smem[q], ref) - << "query=" << queries_host[q] << " smem result mismatch at q=" << q; - EXPECT_EQ(h_gmem[q], h_smem[q]) - << "query=" << queries_host[q] << " gmem != smem at q=" << q; + size_t ref = ref_get_tensor_id(shape_rep, num_tensors, queries_host[q], first_logical_dim, + last_logical_dim, offsets_host); + EXPECT_EQ(h_gmem[q], ref) << "query=" << queries_host[q] << " gmem result mismatch at q=" << q; + EXPECT_EQ(h_smem[q], ref) << "query=" << queries_host[q] << " smem result mismatch at q=" << q; + EXPECT_EQ(h_gmem[q], h_smem[q]) << "query=" << queries_host[q] << " gmem != smem at q=" << q; } cudaFree(d_offsets); @@ -210,13 +195,12 @@ static void run_test(ShapeRepresentation shape_rep, TEST(OffsetCaching, SameBothDims) { // 4 tensors, each with 128 rows, hidden=256 // offsets (in elements): 0, 128*256, 2*128*256, 3*128*256, 4*128*256 - const size_t hidden = 256; + const size_t hidden = 256; const size_t rows_per = 128; - const size_t N = 4; + const size_t N = 4; std::vector offsets(N + 1); - for (size_t i = 0; i <= N; ++i) - offsets[i] = 
static_cast<int64_t>(i * rows_per * hidden); + for (size_t i = 0; i <= N; ++i) offsets[i] = static_cast<int64_t>(i * rows_per * hidden); // Query every row's first element std::vector<size_t> queries; @@ -245,15 +229,15 @@ TEST(OffsetCaching, VaryingFirstDim) { std::vector<size_t> queries; for (size_t t = 0; t < N; ++t) { size_t start = static_cast<size_t>(offsets[t]); - size_t end = static_cast<size_t>(offsets[t + 1]) - 1; - size_t mid = (start + end) / 2; + size_t end = static_cast<size_t>(offsets[t + 1]) - 1; + size_t mid = (start + end) / 2; queries.push_back(start); queries.push_back(mid); queries.push_back(end); } run_test(VARYING_FIRST_DIM, offsets, queries, - /*first_logical_dim=*/0, // unused for binary search path + /*first_logical_dim=*/0, // unused for binary search path /*last_logical_dim=*/hidden); } @@ -270,8 +254,7 @@ TEST(OffsetCaching, ExactBoundary) { // Query exactly at each offset boundary (should map to the tensor that starts there) std::vector<size_t> queries; - for (size_t t = 0; t < N; ++t) - queries.push_back(static_cast<size_t>(offsets[t])); + for (size_t t = 0; t < N; ++t) queries.push_back(static_cast<size_t>(offsets[t])); run_test(VARYING_FIRST_DIM, offsets, queries, /*first_logical_dim=*/0, @@ -281,10 +264,10 @@ // Case 4: single tensor — degenerate case TEST(OffsetCaching, SingleTensor) { const size_t hidden = 256; - const size_t rows = 512; + const size_t rows = 512; std::vector<int64_t> offsets = {0, static_cast<int64_t>(rows * hidden)}; - std::vector<size_t> queries = {0, rows * hidden / 2, rows * hidden - 1}; + std::vector<size_t> queries = {0, rows * hidden / 2, rows * hidden - 1}; run_test(VARYING_FIRST_DIM, offsets, queries, /*first_logical_dim=*/rows, @@ -293,9 +276,9 @@ // Case 5: maximum tensors (kMaxTensors=64) TEST(OffsetCaching, MaxTensors) { - const size_t hidden = 128; + const size_t hidden = 128; const size_t rows_each = 128; - const size_t N = 64; + const size_t N = 64; std::vector<int64_t> offsets(N + 1); offsets[0] = 0; @@ -304,8 +287,7 @@ TEST(OffsetCaching,
MaxTensors) { // One query per tensor std::vector<size_t> queries; - for (size_t t = 0; t < N; ++t) - queries.push_back(static_cast<size_t>(offsets[t])); + for (size_t t = 0; t < N; ++t) queries.push_back(static_cast<size_t>(offsets[t])); run_test(VARYING_FIRST_DIM, offsets, queries, /*first_logical_dim=*/0, From 1f8dcd062d4a9ddbea0949d68093a134e5c2e952 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Wed, 15 Apr 2026 09:13:09 -0700 Subject: [PATCH 5/6] [Hadamard] Use cutlass/cute wrappers instead of raw PTX in graph-safe kernel - Replace constexpr kWarpSize=32 with cutlass::NumThreadsPerWarp - Replace asm volatile cp.async.commit_group/wait_all with cute::cp_async_fence()/cp_async_wait<0>() Signed-off-by: Siddhartha Raman S --- ...row_cast_col_hadamard_transform_cast_fusion.cu | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 7a23529645..b0b7e1b342 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -533,21 +533,20 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g // the 2nd arrival, firing the barrier. Epilogue warps wait on tma_barrier[0] before reading // offsets_smem/first_dims_smem. // For kEnableRHTColQuant=false: cpasync_barrier[0] is used instead.
- constexpr int kWarpSize = 32; - const int local_tidx = threadIdx.x % kWarpSize; + const int local_tidx = threadIdx.x % cutlass::NumThreadsPerWarp; auto async_op = cute::SM80_CP_ASYNC_CACHEALWAYS{}; - for (size_t i = local_tidx; i <= num_tensors; i += kWarpSize) { + for (size_t i = local_tidx; i <= num_tensors; i += cutlass::NumThreadsPerWarp) { async_op.copy(offsets[i], shared_storage.smem_offsets[i]); } - for (size_t i = local_tidx; i < num_tensors; i += kWarpSize) { + for (size_t i = local_tidx; i < num_tensors; i += cutlass::NumThreadsPerWarp) { async_op.copy(first_dims[i], shared_storage.smem_first_dims[i]); } if constexpr (kEnableRHTColQuant) { // Wait for all cp.async offsets/first_dims copies to complete, then provide the // 1st of 2 required arrivals at tma_barrier[0] (with release semantics so consumers // see the smem writes). The 2nd arrival comes from TMA B hardware completion below. - asm volatile("cp.async.commit_group;\n" ::); - asm volatile("cp.async.wait_all;\n" ::); + cute::cp_async_fence(); + cute::cp_async_wait<0>(); __threadfence_block(); if (elect_one_sync()) { transformer_engine::ptx::mbarrier_arrive(&shared_storage.tma_barrier[0]); @@ -555,8 +554,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g } else { // No TMA B in this path. Block until all cp.async ops issued above are complete, then // signal cpasync_barrier[0] so the row quant warp can safely read offsets_smem. 
- asm volatile("cp.async.commit_group;\n" ::); - asm volatile("cp.async.wait_all;\n" ::); + cute::cp_async_fence(); + cute::cp_async_wait<0>(); __threadfence_block(); if (elect_one_sync()) { transformer_engine::ptx::mbarrier_arrive(&shared_storage.cpasync_barrier[0]); From 7f1980a998dfc65fe2e7207f2442088300f19380 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Wed, 15 Apr 2026 09:41:18 -0700 Subject: [PATCH 6/6] Add copyright headers to nvfp4 benchmark files Signed-off-by: Siddhartha Raman S --- tests/pytorch/nvfp4/bench_graph_safe_swizzle.py | 4 ++++ tests/pytorch/nvfp4/bench_search.py | 4 ++++ tests/pytorch/nvfp4/bench_structural.py | 4 ++++ tests/pytorch/nvfp4/bench_sweep_swizzle.py | 4 ++++ tests/pytorch/nvfp4/ncu_test.py | 4 ++++ 5 files changed, 20 insertions(+) diff --git a/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py index 0672afea24..4a910ba333 100644 --- a/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py +++ b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + import torch import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer diff --git a/tests/pytorch/nvfp4/bench_search.py b/tests/pytorch/nvfp4/bench_search.py index 7f012c5951..7fac1611b9 100644 --- a/tests/pytorch/nvfp4/bench_search.py +++ b/tests/pytorch/nvfp4/bench_search.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer diff --git a/tests/pytorch/nvfp4/bench_structural.py b/tests/pytorch/nvfp4/bench_structural.py index 9d20da3ffd..57cc8fceba 100644 --- a/tests/pytorch/nvfp4/bench_structural.py +++ b/tests/pytorch/nvfp4/bench_structural.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer diff --git a/tests/pytorch/nvfp4/bench_sweep_swizzle.py b/tests/pytorch/nvfp4/bench_sweep_swizzle.py index 9f86afee78..f60629d0e8 100644 --- a/tests/pytorch/nvfp4/bench_sweep_swizzle.py +++ b/tests/pytorch/nvfp4/bench_sweep_swizzle.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer diff --git a/tests/pytorch/nvfp4/ncu_test.py b/tests/pytorch/nvfp4/ncu_test.py index bff6222937..8453e8cdc2 100644 --- a/tests/pytorch/nvfp4/ncu_test.py +++ b/tests/pytorch/nvfp4/ncu_test.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer