diff --git a/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py
new file mode 100644
index 0000000000..4a910ba333
--- /dev/null
+++ b/tests/pytorch/nvfp4/bench_graph_safe_swizzle.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+import transformer_engine_torch as tex
+from transformer_engine.pytorch import NVFP4Quantizer
+
+M, N = 8192, 7168  # your actual shape
+x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+split_sections = torch.tensor([128] * (M // 128), dtype=torch.int64, device="cuda")
+
+for optimize_for_gemm in [False, True]:
+    q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+    q.optimize_for_gemm = optimize_for_gemm
+
+    # warmup
+    for _ in range(10):
+        tex.group_quantize(x, q, split_sections.shape[0], split_sections)
+    torch.cuda.synchronize()
+
+    # time
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(100):
+        tex.group_quantize(x, q, split_sections.shape[0], split_sections)
+    end.record()
+    torch.cuda.synchronize()
+    print(f"optimize_for_gemm={optimize_for_gemm}: {start.elapsed_time(end) / 100 * 1000:.1f} μs")
diff --git a/tests/pytorch/nvfp4/bench_search.py b/tests/pytorch/nvfp4/bench_search.py
new file mode 100644
index 0000000000..7fac1611b9
--- /dev/null
+++ b/tests/pytorch/nvfp4/bench_search.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import transformer_engine.pytorch as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch import NVFP4Quantizer
+import torch
+import torch.cuda.nvtx as nvtx
+
+N = 7168
+num_experts = 64
+ITERS = 50
+
+M_VALUES = [8192, 16384, 32768, 65536, 131072]
+
+
+def make_unequal_splits(M, num_experts):
+    base = M // num_experts
+    splits = []
+    for i in range(num_experts):
+        if i % 2 == 0:
+            splits.append(base - 128)
+        else:
+            splits.append(base + 128)
+    # fix rounding so sum == M
+    diff = M - sum(splits)
+    splits[-1] += diff
+    return splits
+
+
+def bench(fn, label, iters=ITERS):
+    for _ in range(10):
+        fn()
+    torch.cuda.synchronize()
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    nvtx.range_push(label)
+    start.record()
+    for _ in range(iters):
+        fn()
+    end.record()
+    nvtx.range_pop()
+    torch.cuda.synchronize()
+    us = start.elapsed_time(end) / iters * 1000
+    print(f"  {label}: {us:.1f} us")
+    return us
+
+
+print(f"N={N}, num_experts={num_experts}")
+print("-" * 60)
+
+for M in M_VALUES:
+    if M % num_experts != 0 or (M // num_experts) <= 128:
+        print(f"M={M}: skipped")
+        continue
+
+    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+    label_prefix = f"M{M}"
+
+    print(f"\nM={M}:")
+
+    # --- graph-safe, equal splits (O(1) division) ---
+    equal_splits = [M // num_experts] * num_experts
+    equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda")
+    q_eq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+    q_eq.optimize_for_gemm = False
+    bench(
+        lambda: tex.group_quantize(x, q_eq, num_experts, equal_tensor),
+        f"{label_prefix}_graph_safe_equal_O1",
+    )
+
+    # --- graph-safe, unequal splits (binary search) ---
+    unequal_splits = make_unequal_splits(M, num_experts)
+    unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda")
+    q_uneq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+    q_uneq.optimize_for_gemm = False
+    bench(
+        lambda: tex.group_quantize(x, q_uneq, num_experts, unequal_tensor),
+        f"{label_prefix}_graph_safe_unequal_bsearch",
+    )
+
+    # --- non-graph-safe (linear scan) ---
+    q_list = [
+        NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+        for _ in range(num_experts)
+    ]
+    bench(
+        lambda: tex.split_quantize(x, equal_splits, q_list),
+        f"{label_prefix}_non_graph_safe_linear",
+    )
diff --git a/tests/pytorch/nvfp4/bench_structural.py b/tests/pytorch/nvfp4/bench_structural.py
new file mode 100644
index 0000000000..57cc8fceba
--- /dev/null
+++ b/tests/pytorch/nvfp4/bench_structural.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import transformer_engine.pytorch as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch import NVFP4Quantizer
+import torch
+import torch.cuda.nvtx as nvtx
+
+N = 7168
+num_experts = 64
+
+
+def make_quantizer():
+    q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+    q.optimize_for_gemm = True
+    return q
+
+
+def bench(fn, label, iters=100):
+    for _ in range(10):
+        fn()
+    torch.cuda.synchronize()
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    nvtx.range_push(label)
+    start.record()
+    for _ in range(iters):
+        fn()
+    end.record()
+    nvtx.range_pop()
+    torch.cuda.synchronize()
+    print(f"{label}: {start.elapsed_time(end) / iters * 1000:.1f} us")
+
+
+for M in [16384, 65536, 131072]:
+    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+
+    # 1. graph-safe + equal splits -> O(1) division (SAME_BOTH_DIMS)
+    equal_splits = [M // num_experts] * num_experts
+    equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda")
+    q1 = make_quantizer()
+    bench(
+        lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), f"[M={M}] graph_safe_equal_O1"
+    )
+
+    # 2. graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM)
+    base = M // num_experts
+    unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)]
+    unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda")
+    q2 = make_quantizer()
+    bench(
+        lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor),
+        f"[M={M}] graph_safe_unequal_binary_search",
+    )
+
+    # 3. non-graph-safe + linear scan (GetGroupIdx)
+    q_list = [
+        NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+        for _ in range(num_experts)
+    ]
+    bench(
+        lambda: tex.split_quantize(x, equal_splits, q_list), f"[M={M}] non_graph_safe_linear_scan"
+    )
+
+    print()
diff --git a/tests/pytorch/nvfp4/bench_sweep_swizzle.py b/tests/pytorch/nvfp4/bench_sweep_swizzle.py
new file mode 100644
index 0000000000..f60629d0e8
--- /dev/null
+++ b/tests/pytorch/nvfp4/bench_sweep_swizzle.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import transformer_engine.pytorch as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch import NVFP4Quantizer
+import torch
+import torch.cuda.nvtx as nvtx
+
+N = 7168
+num_experts = 64
+ITERS = 50
+
+M_VALUES = [8192, 16384, 32768, 65536, 131072]
+
+
+def bench(fn, label, iters=ITERS):
+    # warmup
+    for _ in range(10):
+        fn()
+    torch.cuda.synchronize()
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    nvtx.range_push(label)
+    start.record()
+    for _ in range(iters):
+        fn()
+    end.record()
+    nvtx.range_pop()
+    torch.cuda.synchronize()
+    us = start.elapsed_time(end) / iters * 1000
+    print(f"  {label}: {us:.1f} us")
+    return us
+
+
+print(f"N={N}, num_experts={num_experts}")
+print("-" * 60)
+
+for M in M_VALUES:
+    if M % num_experts != 0:
+        print(f"M={M}: skipped (not divisible by num_experts={num_experts})")
+        continue
+
+    rows_per_expert = M // num_experts
+    split_sections = [rows_per_expert] * num_experts
+    split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda")
+    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+
+    print(f"\nM={M} ({rows_per_expert} rows/expert):")
+
+    label_prefix = f"M{M}"
+
+    # --- graph-safe, swizzle ON ---
+    q_on = NVFP4Quantizer(
+        rowwise=True,
+        columnwise=True,
+        with_rht=True,
+        with_post_rht_amax=True,
+    )
+    q_on.optimize_for_gemm = True
+    bench(
+        lambda: tex.group_quantize(x, q_on, num_experts, split_section_tensor),
+        f"{label_prefix}_graph_safe_swizzle_ON",
+    )
+
+    # --- graph-safe, swizzle OFF ---
+    q_off = NVFP4Quantizer(
+        rowwise=True,
+        columnwise=True,
+        with_rht=True,
+        with_post_rht_amax=True,
+    )
+    q_off.optimize_for_gemm = False
+    bench(
+        lambda: tex.group_quantize(x, q_off, num_experts, split_section_tensor),
+        f"{label_prefix}_graph_safe_swizzle_OFF",
+    )
+
+    # --- non-graph-safe ---
+    q_list = [
+        NVFP4Quantizer(
+            rowwise=True,
+            columnwise=True,
+            with_rht=True,
+            with_post_rht_amax=True,
+        )
+        for _ in range(num_experts)
+    ]
+    bench(
+        lambda: tex.split_quantize(x, split_sections, q_list),
+        f"{label_prefix}_non_graph_safe",
+    )
diff --git a/tests/pytorch/nvfp4/ncu_test.py b/tests/pytorch/nvfp4/ncu_test.py
new file mode 100644
index 0000000000..8453e8cdc2
--- /dev/null
+++ b/tests/pytorch/nvfp4/ncu_test.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import transformer_engine.pytorch as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch import NVFP4Quantizer
+import torch
+
+M, N, num_experts = 16384, 7168, 64
+x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+splits = [M // num_experts] * num_experts
+split_tensor = torch.tensor(splits, dtype=torch.int64, device="cuda")
+
+# warmup
+q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
+for _ in range(3):
+    tex.group_quantize(x, q, num_experts, split_tensor)
+torch.cuda.synchronize()
+
+# single measured launch
+tex.group_quantize(x, q, num_experts, split_tensor)
+torch.cuda.synchronize()
diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
index 0c3a5e9299..b0b7e1b342 100644
--- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
+++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
@@ -162,6 +162,9 @@ struct SharedStorage {
   alignas(16) AccumulatorPipelineStorage accumulator;
   alignas(16) MainloopPipelineStorage mainloop;
   alignas(16) cute::uint64_t tma_barrier[1];
+  // Used only when kEnableRHTColQuant=false: signals cp.async completion of offsets/first_dims
+  // to the row quant warp without going through tma_barrier[0] (which is never signaled in that path).
+  alignas(16) cute::uint64_t cpasync_barrier[1];
   alignas(16) SchedPipelineStorage sched;
   alignas(16) SchedThrottlePipelineStorage sched_throttle;
   alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_];
@@ -169,6 +172,8 @@ struct SharedStorage {
   alignas(16) float global_d_amax[kMaxTensorsPerKernel];
   uint32_t atomic_tile_counter[SchedulerPipelineStageCount_];
   uint32_t tmem_base_ptr;
+  alignas(8) int64_t smem_offsets[kMaxTensorsPerKernel + 1];
+  alignas(8) int64_t smem_first_dims[kMaxTensorsPerKernel];
 };
 
 // Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support
@@ -376,6 +381,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
                      EpilogueUnrollFactor, SchedulerPipelineStageCount>;
   SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
 
+  // offsets_smem and first_dims_smem point to the smem cache populated by the DMA warp via
+  // LDGSTS (cp.async). tma_barrier[0] is initialized with 2 arrivals: the DMA warp arrives
+  // after cp.async.wait_all (ensuring offsets/first_dims are in smem), and TMA B hardware
+  // arrives when the Hadamard matrix load completes. Epilogue warps wait on tma_barrier[0].
+  const int64_t *const offsets_smem = shared_storage.smem_offsets;
+  const int64_t *const first_dims_smem = shared_storage.smem_first_dims;
+
   // Compute the number of tiles in M and N after tiling and assign scheduler
   uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile))));
   uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler))));
@@ -500,7 +512,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
 
       cutlass::make_producer_start_state();
   if (warp_idx == 2 && elect_one_sync()) {
-    cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
+    if constexpr (kEnableRHTColQuant) {
+      // Two arrivals required: (1) DMA warp after cp.async completes, (2) TMA B hardware completion.
+      cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 2);
+    } else {
+      cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
+      cute::initialize_barrier(shared_storage.cpasync_barrier[0], /* num_threads */ 1);
+    }
   }
 
   __syncthreads();
@@ -508,6 +526,42 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
   if (is_dma_warp) {
     // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access).
     cutlass::arch::warpgroup_reg_dealloc<32>();
+
+    // Use LDGSTS (cp.async) in the DMA warp to load offsets and first_dims into shared memory.
+    // For kEnableRHTColQuant=true: after cp.async completes, the DMA warp provides the 1st of
+    // 2 arrivals at tma_barrier[0] (with release semantics). TMA B hardware completion provides
+    // the 2nd arrival, firing the barrier. Epilogue warps wait on tma_barrier[0] before reading
+    // offsets_smem/first_dims_smem.
+    // For kEnableRHTColQuant=false: cpasync_barrier[0] is used instead.
+    const int local_tidx = threadIdx.x % cutlass::NumThreadsPerWarp;
+    auto async_op = cute::SM80_CP_ASYNC_CACHEALWAYS<int64_t>{};
+    for (size_t i = local_tidx; i <= num_tensors; i += cutlass::NumThreadsPerWarp) {
+      async_op.copy(offsets[i], shared_storage.smem_offsets[i]);
+    }
+    for (size_t i = local_tidx; i < num_tensors; i += cutlass::NumThreadsPerWarp) {
+      async_op.copy(first_dims[i], shared_storage.smem_first_dims[i]);
+    }
+    if constexpr (kEnableRHTColQuant) {
+      // Wait for all cp.async offsets/first_dims copies to complete, then provide the
+      // 1st of 2 required arrivals at tma_barrier[0] (with release semantics so consumers
+      // see the smem writes). The 2nd arrival comes from TMA B hardware completion below.
+      cute::cp_async_fence();
+      cute::cp_async_wait<0>();
+      __threadfence_block();
+      if (elect_one_sync()) {
+        transformer_engine::ptx::mbarrier_arrive(&shared_storage.tma_barrier[0]);
+      }
+    } else {
+      // No TMA B in this path. Block until all cp.async ops issued above are complete, then
+      // signal cpasync_barrier[0] so the row quant warp can safely read offsets_smem.
+      cute::cp_async_fence();
+      cute::cp_async_wait<0>();
+      __threadfence_block();
+      if (elect_one_sync()) {
+        transformer_engine::ptx::mbarrier_arrive(&shared_storage.cpasync_barrier[0]);
+      }
+    }
+
     // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory.
     Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N));
     Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize));
@@ -704,14 +758,16 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
       rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
       rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
     }
 
+    // tma_barrier[0] completes after both B tensor TMA and offsets cp.async finish.
+    cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
     // TODO(zhongbo): double check the logic here
     int group_idx = get_current_tensor_id(
         shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
-        packed_N, M, offsets);
+        packed_N, M, offsets_smem);
     // Determine quantization scale factor layouts/output splits for this group
     TSFDLayout sfd_layout;
-    int cur_N = static_cast<int>(first_dims[group_idx]);
+    int cur_N = static_cast<int>(first_dims_smem[group_idx]);
     if constexpr (kEnableSwizzleSFOutput) {
       sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
     } else {
@@ -772,7 +828,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
 
         // TODO(zhongbo): double check the logic here
         int cur_group_idx = get_current_tensor_id(
-            shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets);
+            shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets_smem);
 
         if (cur_group_idx != group_idx) {
           group_idx = cur_group_idx;
@@ -786,7 +842,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
           global_decode_scale = 1.0f / global_encode_scale;
           global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
           // TODO(zhongbo): double check the logic here
-          cur_N = first_dims[group_idx];
+          cur_N = first_dims_smem[group_idx];
           if constexpr (kEnableSwizzleSFOutput) {
             sfd_layout =
                 tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
@@ -999,9 +1055,17 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
     constexpr int row_quant_barrier_id = 2;
     cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id);
 
+    // Wait until offsets/first_dims cp.async copies are visible in smem.
+    // When kEnableRHTColQuant=true, tma_barrier[0] covers both B TMA and cp.async.
+    // When kEnableRHTColQuant=false, cpasync_barrier[0] covers cp.async only.
+    if constexpr (kEnableRHTColQuant) {
+      cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*phase_bit*/);
+    } else {
+      cute::wait_barrier(shared_storage.cpasync_barrier[0], 0 /*phase_bit*/);
+    }
     int group_idx = get_current_tensor_id(
         shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
-        packed_N, M, offsets);
+        packed_N, M, offsets_smem);
     float a_global_amax_val = shared_storage.global_a_amax[group_idx];
     // Aligning with TensorEngine's recipe to generate scale factors  // {$nv-internal-release}
     static constexpr float fp4_max = 6.0f;
@@ -1023,7 +1087,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g
 
      int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
      int cur_group_idx = get_current_tensor_id(
-         shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets);
+         shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets_smem);
      if (cur_group_idx != group_idx) {
        group_idx = cur_group_idx;
        a_global_amax_val = shared_storage.global_a_amax[group_idx];
@@ -1259,10 +1323,15 @@ void group_row_col_rht_gemm_ntt_w_sfc_graph_safe(
   static int constexpr kBlackwellSmemSize = 232448;  // 232KB in bytes
   static int constexpr kBytesPerStage =
       cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes;
-  static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes +
-                                        SchedulerPipelineBytes + TmemBasePtrsBytes +
-                                        TmemDeallocBytes + BTensorBytes +
-                                        AccPipelineBytes;  // Reserve for barriers and other uses
+  // smem_offsets: (kMaxTensorsPerKernel+1) x int64, smem_first_dims: kMaxTensorsPerKernel x int64
+  // Note: tma_barrier[1] and cpasync_barrier[1] (8 bytes each) are not counted here; their impact
+  // on kMaxStages is negligible (<0.01% of the 232KB budget).
+  static int constexpr kSmemOffsetCacheBytes =
+      sizeof(int64_t) * (kMaxTensorsPerKernel + 1) + sizeof(int64_t) * kMaxTensorsPerKernel;
+  static int constexpr kReservedBytes =
+      SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + SchedulerPipelineBytes +
+      TmemBasePtrsBytes + TmemDeallocBytes + BTensorBytes + AccPipelineBytes +
+      kSmemOffsetCacheBytes;  // Reserve for barriers and other uses
   static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
   auto sP = Int<kMaxStages>{};  // SMEM pipelines
 
diff --git a/transformer_engine/common/hadamard_transform/test_offset_caching.cu b/transformer_engine/common/hadamard_transform/test_offset_caching.cu
new file mode 100644
index 0000000000..d6db4bb990
--- /dev/null
+++ b/transformer_engine/common/hadamard_transform/test_offset_caching.cu
@@ -0,0 +1,295 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+// Tests that caching offsets/first_dims from gmem into smem via cp.async
+// produces identical results to reading directly from gmem in
+// get_current_tensor_id().
+
+#include <cstdint>
+#include <vector>
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+#define CUDA_CHECK(expr)                                   \
+  do {                                                     \
+    cudaError_t _e = (expr);                               \
+    ASSERT_EQ(_e, cudaSuccess) << cudaGetErrorString(_e);  \
+  } while (0)
+
+// ---------------------------------------------------------------------------
+// Device-side implementation (mirrors the file under test)
+// ---------------------------------------------------------------------------
+
+enum ShapeRepresentation {
+  SAME_BOTH_DIMS = 0,
+  VARYING_FIRST_DIM = 1,
+  VARYING_LAST_DIM = 2,
+  VARYING_BOTH_DIMS = 3
+};
+
+// Exact copy of get_current_tensor_id from the file under test
+__device__ __forceinline__ size_t get_current_tensor_id(
+    const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
+    const size_t first_logical_dim, const size_t last_logical_dim,
+    const int64_t *const __restrict__ offsets_ptr) {
+  if (shape_rep == SAME_BOTH_DIMS) {
+    const size_t current_row = current_offset / last_logical_dim;
+    const size_t rows_per_tensor = first_logical_dim / num_tensors;
+    return current_row / rows_per_tensor;
+  } else {
+    size_t low = 0, hi = num_tensors;
+    while (low < hi) {
+      const size_t mid = low + (hi - low) / 2;
+      const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
+      if (mid_offset <= current_offset)
+        low = mid + 1;
+      else
+        hi = mid;
+    }
+    return (low == 0) ? 0 : (low - 1);
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Test kernel
+//
+// For every query offset in `queries[]`:
+//   result_gmem[i] = get_current_tensor_id(..., gmem offsets pointer)
+//   result_smem[i] = get_current_tensor_id(..., smem-cached offsets pointer)
+//
+// The smem path uses the same cp.async + wait_all + __syncthreads() sequence
+// as the production kernel.
+// ---------------------------------------------------------------------------
+
+constexpr int kMaxTensors = 64;
+
+struct SmemStorage {
+  int64_t offsets[kMaxTensors + 1];
+};
+
+__global__ void test_offset_caching_kernel(
+    const int64_t *gmem_offsets,  // [num_tensors+1]
+    const size_t *queries,        // [num_queries] — offsets to look up
+    size_t *result_gmem,          // [num_queries]
+    size_t *result_smem,          // [num_queries]
+    size_t num_tensors, size_t num_queries, ShapeRepresentation shape_rep, size_t first_logical_dim,
+    size_t last_logical_dim) {
+  extern __shared__ char smem_raw[];
+  SmemStorage &smem = *reinterpret_cast<SmemStorage *>(smem_raw);
+
+  // ---- cooperative cp.async load (mirrors production kernel lines 385-397) ----
+  for (size_t i = threadIdx.x; i <= num_tensors; i += blockDim.x) {
+    uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(&smem.offsets[i]));
+    asm volatile("cp.async.ca.shared.global [%0], [%1], 8;\n" ::"r"(dst),
+                 "l"(reinterpret_cast<const void *>(&gmem_offsets[i])));
+  }
+  asm volatile("cp.async.commit_group;\n" ::);
+  asm volatile("cp.async.wait_all;\n" ::);
+  __syncthreads();
+
+  const int64_t *const offsets_smem = smem.offsets;
+
+  // ---- each thread handles one query ----------------------------------------
+  for (size_t q = threadIdx.x; q < num_queries; q += blockDim.x) {
+    size_t offset = queries[q];
+
+    result_gmem[q] = get_current_tensor_id(shape_rep, num_tensors, offset, first_logical_dim,
+                                           last_logical_dim, gmem_offsets);
+
+    result_smem[q] = get_current_tensor_id(shape_rep, num_tensors, offset, first_logical_dim,
+                                           last_logical_dim, offsets_smem);
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Host-side reference — pure C++, no CUDA
+// ---------------------------------------------------------------------------
+
+static size_t ref_get_tensor_id(ShapeRepresentation shape_rep, size_t num_tensors,
+                                size_t current_offset, size_t first_logical_dim,
+                                size_t last_logical_dim, const std::vector<int64_t> &offsets) {
+  if (shape_rep == SAME_BOTH_DIMS) {
+    size_t current_row = current_offset / last_logical_dim;
+    size_t rows_per_tensor = first_logical_dim / num_tensors;
+    return current_row / rows_per_tensor;
+  } else {
+    size_t low = 0, hi = num_tensors;
+    while (low < hi) {
+      size_t mid = low + (hi - low) / 2;
+      if (static_cast<size_t>(offsets[mid]) <= current_offset)
+        low = mid + 1;
+      else
+        hi = mid;
+    }
+    return (low == 0) ? 0 : (low - 1);
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Helper: run kernel + compare
+// ---------------------------------------------------------------------------
+
+static void run_test(ShapeRepresentation shape_rep, const std::vector<int64_t> &offsets_host,
+                     const std::vector<size_t> &queries_host, size_t first_logical_dim,
+                     size_t last_logical_dim) {
+  const size_t num_tensors = offsets_host.size() - 1;  // offsets has num_tensors+1 entries
+  const size_t num_queries = queries_host.size();
+
+  // --- allocate device memory -----------------------------------------------
+  int64_t *d_offsets = nullptr;
+  size_t *d_queries = nullptr, *d_result_gmem = nullptr, *d_result_smem = nullptr;
+
+  CUDA_CHECK(cudaMalloc(&d_offsets, (num_tensors + 1) * sizeof(int64_t)));
+  CUDA_CHECK(cudaMalloc(&d_queries, num_queries * sizeof(size_t)));
+  CUDA_CHECK(cudaMalloc(&d_result_gmem, num_queries * sizeof(size_t)));
+  CUDA_CHECK(cudaMalloc(&d_result_smem, num_queries * sizeof(size_t)));
+
+  CUDA_CHECK(cudaMemcpy(d_offsets, offsets_host.data(), (num_tensors + 1) * sizeof(int64_t),
+                        cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy(d_queries, queries_host.data(), num_queries * sizeof(size_t),
+                        cudaMemcpyHostToDevice));
+
+  // --- launch ---------------------------------------------------------------
+  int smem_bytes = sizeof(SmemStorage);
+  test_offset_caching_kernel<<<1, 128, smem_bytes>>>(
+      d_offsets, d_queries, d_result_gmem, d_result_smem, num_tensors, num_queries, shape_rep,
+      first_logical_dim, last_logical_dim);
+  CUDA_CHECK(cudaGetLastError());
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  // --- copy results back ----------------------------------------------------
+  std::vector<size_t> h_gmem(num_queries), h_smem(num_queries);
+  CUDA_CHECK(cudaMemcpy(h_gmem.data(), d_result_gmem, num_queries * sizeof(size_t),
+                        cudaMemcpyDeviceToHost));
+  CUDA_CHECK(cudaMemcpy(h_smem.data(), d_result_smem, num_queries * sizeof(size_t),
+                        cudaMemcpyDeviceToHost));
+
+  // --- verify: gmem == smem == host_ref -------------------------------------
+  for (size_t q = 0; q < num_queries; ++q) {
+    size_t ref = ref_get_tensor_id(shape_rep, num_tensors, queries_host[q], first_logical_dim,
+                                   last_logical_dim, offsets_host);
+    EXPECT_EQ(h_gmem[q], ref) << "query=" << queries_host[q] << " gmem result mismatch at q=" << q;
+    EXPECT_EQ(h_smem[q], ref) << "query=" << queries_host[q] << " smem result mismatch at q=" << q;
+    EXPECT_EQ(h_gmem[q], h_smem[q]) << "query=" << queries_host[q] << " gmem != smem at q=" << q;
+  }
+
+  cudaFree(d_offsets);
+  cudaFree(d_queries);
+  cudaFree(d_result_gmem);
+  cudaFree(d_result_smem);
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+// Case 1: SAME_BOTH_DIMS — 4 tensors, equal rows each
+TEST(OffsetCaching, SameBothDims) {
+  // 4 tensors, each with 128 rows, hidden=256
+  // offsets (in elements): 0, 128*256, 2*128*256, 3*128*256, 4*128*256
+  const size_t hidden = 256;
+  const size_t rows_per = 128;
+  const size_t N = 4;
+
+  std::vector<int64_t> offsets(N + 1);
+  for (size_t i = 0; i <= N; ++i) offsets[i] = static_cast<int64_t>(i * rows_per * hidden);
+
+  // Query every row's first element
+  std::vector<size_t> queries;
+  for (size_t t = 0; t < N; ++t)
+    for (size_t r = 0; r < rows_per; ++r)
+      queries.push_back(static_cast<size_t>(offsets[t]) + r * hidden);
+
+  run_test(SAME_BOTH_DIMS, offsets, queries,
+           /*first_logical_dim=*/N * rows_per,
+           /*last_logical_dim=*/hidden);
+}
+
+// Case 2: VARYING_FIRST_DIM — tensors with different numbers of rows
+TEST(OffsetCaching, VaryingFirstDim) {
+  // 3 tensors with row counts 128, 256, 192; hidden=512
+  const size_t hidden = 512;
+  const std::vector<size_t> row_counts = {128, 256, 192};
+  const size_t N = row_counts.size();
+
+  std::vector<int64_t> offsets(N + 1);
+  offsets[0] = 0;
+  for (size_t i = 0; i < N; ++i)
+    offsets[i + 1] = offsets[i] + static_cast<int64_t>(row_counts[i] * hidden);
+
+  // Query: first element, last element, and middle element of each tensor
+  std::vector<size_t> queries;
+  for (size_t t = 0; t < N; ++t) {
+    size_t start = static_cast<size_t>(offsets[t]);
+    size_t end = static_cast<size_t>(offsets[t + 1]) - 1;
+    size_t mid = (start + end) / 2;
+    queries.push_back(start);
+    queries.push_back(mid);
+    queries.push_back(end);
+  }
+
+  run_test(VARYING_FIRST_DIM, offsets, queries,
+           /*first_logical_dim=*/0,  // unused for binary search path
+           /*last_logical_dim=*/hidden);
+}
+
+// Case 3: boundary — query lands exactly on an offset boundary
+TEST(OffsetCaching, ExactBoundary) {
+  const size_t hidden = 128;
+  const std::vector<size_t> row_counts = {128, 128, 256, 64};
+  const size_t N = row_counts.size();
+
+  std::vector<int64_t> offsets(N + 1);
+  offsets[0] = 0;
+  for (size_t i = 0; i < N; ++i)
+    offsets[i + 1] = offsets[i] + static_cast<int64_t>(row_counts[i] * hidden);
+
+  // Query exactly at each offset boundary (should map to the tensor that starts there)
+  std::vector<size_t> queries;
+  for (size_t t = 0; t < N; ++t) queries.push_back(static_cast<size_t>(offsets[t]));
+
+  run_test(VARYING_FIRST_DIM, offsets, queries,
+           /*first_logical_dim=*/0,
+           /*last_logical_dim=*/hidden);
+}
+
+// Case 4: single tensor — degenerate case
+TEST(OffsetCaching, SingleTensor) {
+  const size_t hidden = 256;
+  const size_t rows = 512;
+
+  std::vector<int64_t> offsets = {0, static_cast<int64_t>(rows * hidden)};
+  std::vector<size_t> queries = {0, rows * hidden / 2, rows * hidden - 1};
+
+  run_test(VARYING_FIRST_DIM, offsets, queries,
+           /*first_logical_dim=*/rows,
+           /*last_logical_dim=*/hidden);
+}
+
+// Case 5: maximum tensors (kMaxTensors=64)
+TEST(OffsetCaching, MaxTensors) {
+  const size_t hidden = 128;
+  const size_t rows_each = 128;
+  const size_t N = 64;
+
+  std::vector<int64_t> offsets(N + 1);
+  offsets[0] = 0;
+  for (size_t i = 0; i < N; ++i)
+    offsets[i + 1] = offsets[i] + static_cast<int64_t>(rows_each * hidden);
+
+  // One query per tensor
+  std::vector<size_t> queries;
+  for (size_t t = 0; t < N; ++t) queries.push_back(static_cast<size_t>(offsets[t]));
+
+  run_test(VARYING_FIRST_DIM, offsets, queries,
+           /*first_logical_dim=*/0,
+           /*last_logical_dim=*/hidden);
+}
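
Note (editorial, not part of the patch): the split_sections tensors built by the benchmarks above are per-group row counts, while get_current_tensor_id() in the kernel and in test_offset_caching.cu works on a cumulative offsets array; the GTest builds that array as the running prefix sum of rows * hidden per group. The Python sketch below mirrors that relationship and the two lookup paths being benchmarked: O(1) division when every group has the same row count, binary search over the prefix sums otherwise. It is illustrative only; build_offsets and group_index are hypothetical helper names, not Transformer Engine APIs.

# Illustrative sketch only (not a TE API): host-side mirror of the two lookup paths.
import bisect

def build_offsets(split_sections, hidden):
    # offsets[i] = element offset at which group i starts; len(offsets) == num_groups + 1
    offsets = [0]
    for rows in split_sections:
        offsets.append(offsets[-1] + rows * hidden)
    return offsets

def group_index(element_offset, split_sections, hidden, offsets):
    if len(set(split_sections)) == 1:
        # SAME_BOTH_DIMS: every group has the same number of rows -> O(1) division
        rows_per_group = split_sections[0]
        return (element_offset // hidden) // rows_per_group
    # VARYING_FIRST_DIM: binary search over the prefix-sum offsets,
    # matching the offsets_ptr[mid] <= current_offset convention in the tests above
    return bisect.bisect_right(offsets, element_offset) - 1

splits = [128, 256, 192]  # unequal row counts, as in the VaryingFirstDim test
hidden = 512
offs = build_offsets(splits, hidden)
assert group_index(0, splits, hidden, offs) == 0
assert group_index(128 * hidden, splits, hidden, offs) == 1  # exact boundary maps to the group that starts there
assert group_index(offs[-1] - 1, splits, hidden, offs) == 2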