
SMEM offset caching RHT#2882

Open
sraman-rgb wants to merge 6 commits into NVIDIA:main from sraman-rgb:feat/smem-offset-caching-hadamard

Conversation

@sraman-rgb

Description

Cache offsets and first_dims in shared memory for the graph-safe kernel, reducing the number of bytes read from global memory.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sraman-rgb sraman-rgb force-pushed the feat/smem-offset-caching-hadamard branch from d241123 to 2736535 Compare April 15, 2026 01:11
@greptile-apps
Contributor

greptile-apps bot commented Apr 15, 2026

Greptile Summary

This PR caches offsets and first_dims from global memory into shared memory via cp.async (LDGSTS) in the DMA warp, then uses mbarrier synchronization to ensure the data is visible before epilogue warps read it. The barrier strategy is well-designed: for kEnableRHTColQuant=true, tma_barrier[0] is promoted from 1 to 2 required arrivals (DMA cp.async completion + TMA B hardware) so both data-readiness conditions share a single wait; for kEnableRHTColQuant=false, a new cpasync_barrier[0] carries just the cp.async signal to the row-quant warp without interfering with TMA B synchronization.

The memory ordering is correct — cp_async_wait<0>() + __threadfence_block() + mbarrier_arrive (release) paired with mbarrier_wait (acquire) in consumers is the standard LDGSTS→mbarrier pattern. One remaining gmem read that the smem cache could eliminate is noted inline.
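For intuition, the two-arrival wait described above can be modeled with a toy counting barrier in plain Python. This is only an illustration of the arrival/wait semantics (waiters block until the expected number of arrivals has been signalled), not the CUDA implementation; `CountingBarrier` and the thread names are invented here.

```python
import threading

class CountingBarrier:
    """Toy model of a CUDA mbarrier: waiters block until the
    expected number of arrivals has been signalled."""
    def __init__(self, expected_arrivals):
        self.expected = expected_arrivals
        self.arrivals = 0
        self.cond = threading.Condition()

    def arrive(self):
        with self.cond:
            self.arrivals += 1
            if self.arrivals >= self.expected:
                self.cond.notify_all()

    def wait(self):
        with self.cond:
            self.cond.wait_for(lambda: self.arrivals >= self.expected)

# kEnableRHTColQuant=true path: the barrier needs BOTH the cp.async
# completion (DMA warp) and the TMA B hardware arrival before any
# consumer warp may read the cached offsets/first_dims.
tma_barrier = CountingBarrier(expected_arrivals=2)
events = []

def dma_warp():
    events.append("cp.async done")   # stands in for cp_async_wait<0>()
    tma_barrier.arrive()             # arrival 1

def tma_hardware():
    events.append("TMA B load done")
    tma_barrier.arrive()             # arrival 2

def row_quant_warp():
    tma_barrier.wait()               # blocks until both arrivals
    events.append("read offsets_smem")

threads = [threading.Thread(target=f)
           for f in (row_quant_warp, dma_warp, tma_hardware)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(events[-1])  # "read offsets_smem" — always last
```

The consumer can never observe the shared data before both producers have arrived, which is exactly the property the 2-arrival promotion of tma_barrier[0] provides.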

Confidence Score: 5/5

Safe to merge; the barrier logic and memory ordering are correct and the only finding is a P2 missed-optimization suggestion.

All barrier initialization, arrive, and wait sequences are correct for both the kEnableRHTColQuant=true (2-arrival tma_barrier) and false (cpasync_barrier) paths. The __syncthreads after barrier init, the cp_async_wait+threadfence_block release pattern, and the constexpr-guarded wait selection are all sound. The sole inline comment is a P2 optimization opportunity (one per-tile gmem read of offsets[group_idx] that could use offsets_smem), not a correctness defect.

graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu — verify whether the per-tile offsets[group_idx] access at line 869 is intentionally left in gmem or should be switched to offsets_smem.

Important Files Changed

Filename Overview
transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu Adds SMEM caching of offsets/first_dims via cp.async with correct barrier synchronization; one per-tile gmem read of offsets[group_idx] at line 869 remains in the hot path despite the smem cache being available.
transformer_engine/common/hadamard_transform/test_offset_caching.cu New GTest unit test verifying cp.async-cached offsets match gmem results; build system integration noted as pre-existing concern.
tests/pytorch/nvfp4/bench_graph_safe_swizzle.py New benchmark script for graph-safe swizzle; module-level GPU execution noted as pre-existing concern.
tests/pytorch/nvfp4/bench_structural.py Structural benchmark covering O(1), binary-search, and linear-scan paths; module-level GPU execution noted as pre-existing concern.
tests/pytorch/nvfp4/bench_search.py New benchmark comparing graph-safe equal/unequal splits and non-graph-safe quantization; module-level GPU execution noted as pre-existing concern.
tests/pytorch/nvfp4/bench_sweep_swizzle.py Sweeps swizzle ON/OFF across M values; module-level GPU execution noted as pre-existing concern.
tests/pytorch/nvfp4/ncu_test.py Minimal NCU profiling harness with warmup + single measured launch; module-level GPU execution noted as pre-existing concern.

Sequence Diagram

sequenceDiagram
    participant DMA as DMA Warp
    participant ColQ as Col Quant Warp
    participant RowQ as Row Quant Warp
    participant TMAB as TMA B Hardware

    Note over DMA: cp.async offsets[] → smem_offsets<br/>cp.async first_dims[] → smem_first_dims
    DMA->>DMA: cp_async_fence() + cp_async_wait<0>() + __threadfence_block()

    alt kEnableRHTColQuant=true
        DMA->>DMA: mbarrier_arrive(tma_barrier[0])  ①
        DMA->>TMAB: launch TMA B load
        TMAB-->>DMA: (async)
        TMAB->>ColQ: mbarrier_arrive(tma_barrier[0])  ②
        Note over ColQ: wait_barrier(tma_barrier[0], phase=0)<br/>fires after both ① and ②
        Note over RowQ: wait_barrier(tma_barrier[0], phase=0)<br/>fires after both ① and ②
        ColQ->>ColQ: read offsets_smem / first_dims_smem
        RowQ->>RowQ: read offsets_smem / first_dims_smem
    else kEnableRHTColQuant=false
        DMA->>DMA: mbarrier_arrive(cpasync_barrier[0])  ①
        DMA->>TMAB: launch TMA B load
        TMAB->>TMAB: mbarrier_arrive(tma_barrier[0]) - no waiter
        Note over RowQ: wait_barrier(cpasync_barrier[0], phase=0)<br/>fires after ①
        RowQ->>RowQ: read offsets_smem / first_dims_smem
    end

Reviews (5): Last reviewed commit: "Add copyright headers to nvfp4 benchmark..."

Comment on lines +1 to +63
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
import torch
import torch.cuda.nvtx as nvtx

N = 7168
num_experts = 64


def make_quantizer():
    q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
    q.optimize_for_gemm = True
    return q


def bench(fn, label, iters=100):
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    nvtx.range_push(label)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    nvtx.range_pop()
    torch.cuda.synchronize()
    print(f"{label}: {start.elapsed_time(end) / iters * 1000:.1f} us")


for M in [16384, 65536, 131072]:
    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")

    # 1. graph-safe + equal splits -> O(1) division (SAME_BOTH_DIMS)
    equal_splits = [M // num_experts] * num_experts
    equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda")
    q1 = make_quantizer()
    bench(
        lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), f"[M={M}] graph_safe_equal_O1"
    )

    # 2. graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM)
    base = M // num_experts
    unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)]
    unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda")
    q2 = make_quantizer()
    bench(
        lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor),
        f"[M={M}] graph_safe_unequal_binary_search",
    )

    # 3. non-graph-safe + linear scan (GetGroupIdx)
    q_list = [
        NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
        for _ in range(num_experts)
    ]
    bench(
        lambda: tex.split_quantize(x, equal_splits, q_list), f"[M={M}] non_graph_safe_linear_scan"
    )

    print()
Contributor


P2 Module-level GPU code will execute on pytest import

All five new scripts (bench_structural.py, bench_sweep_swizzle.py, bench_search.py, bench_graph_safe_swizzle.py, ncu_test.py) contain GPU kernel launches at module scope. When pytest discovers files in tests/pytorch/nvfp4/, it imports each one to collect tests; the imports execute the benchmarks immediately — potentially hanging or crashing CI on machines without the required GPU or package.

Wrap the benchmark body in a if __name__ == "__main__": guard on all five files, e.g.:

if __name__ == "__main__":
    for M in [16384, 65536, 131072]:
        ...
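The effect of the guard can be demonstrated with a small stdlib-only sketch: a stand-in module whose "benchmark" body only runs under `__main__`, imported the same way pytest collection would import it. The module name `bench_demo` and its contents are invented for illustration.

```python
import importlib.util
import pathlib
import tempfile

# Minimal stand-in for one of the benchmark scripts: the guarded
# body flips RAN only when run as a script, never on import.
module_src = """\
RAN = False

def benchmark():
    global RAN
    RAN = True

if __name__ == "__main__":
    benchmark()
"""

with tempfile.TemporaryDirectory() as d:
    path = pathlib.Path(d) / "bench_demo.py"
    path.write_text(module_src)

    # Simulate what pytest collection does: import the file.
    spec = importlib.util.spec_from_file_location("bench_demo", path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    # The guarded body did NOT run on import.
    print(mod.RAN)  # False
```

Because `exec_module` sets `__name__` to the module name rather than `"__main__"`, the benchmark body is skipped, so collection on a GPU-less CI machine stays safe.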

// the 2nd arrival, firing the barrier. Epilogue warps wait on tma_barrier[0] before reading
// offsets_smem/first_dims_smem.
// For kEnableRHTColQuant=false: cpasync_barrier[0] is used instead.
constexpr int kWarpSize = 32;


Use NumThreadsPerWarp from the cutlass header; the redundant local definition is not needed.

// No TMA B in this path. Block until all cp.async ops issued above are complete, then
// signal cpasync_barrier[0] so the row quant warp can safely read offsets_smem.
asm volatile("cp.async.commit_group;\n" ::);
asm volatile("cp.async.wait_all;\n" ::);


Use cute::cp_async_fence() and cp_async_wait<0>() instead of the raw inline asm.

sraman-rgb and others added 5 commits April 15, 2026 09:15
…fe kernel

Use LDGSTS (cp.async) in the DMA warp to load offsets and first_dims arrays
from global memory into shared memory, replacing direct global reads in the
epilogue/row-quant warps. Adds cpasync_barrier for the non-RHTColQuant path
and smem_offsets/smem_first_dims fields to SharedStorage. Includes
offset-caching unit tests and swizzle benchmark.

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
- bench_search.py, bench_structural.py, bench_sweep_swizzle.py, ncu_test.py

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
… kernel

- Replace constexpr kWarpSize=32 with cutlass::NumThreadsPerWarp
- Replace asm volatile cp.async.commit_group/wait_all with cute::cp_async_fence()/cp_async_wait<0>()

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the feat/smem-offset-caching-hadamard branch from 7f75817 to 1f8dcd0 Compare April 15, 2026 16:15
Comment on lines +1 to +16
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

// Tests that caching offsets/first_dims from gmem into smem via cp.async
// produces identical results to reading directly from gmem in
// get_current_tensor_id().

#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <cstdint>
#include <vector>

Contributor


P1 Test file never runs — missing build system integration

test_offset_caching.cu uses GTest (TEST(...) fixtures) but has no main() function and is not referenced in any CMakeLists.txt (there is no CMakeLists.txt in transformer_engine/common/hadamard_transform/, and a project-wide search finds no reference to this file). As a result this test is never compiled or executed in CI, so the barrier correctness it is meant to verify is untested.

To wire it in, either add a CMakeLists.txt entry that links against gtest_main, or add a main() directly:

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
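For the CMakeLists.txt option, one possible wiring looks like the fragment below. This is a hedged sketch only: the target name, the use of GTest::gtest_main and gtest_discover_tests, and the assumption that find_package(GTest) / find_package(CUDAToolkit) have already run are illustrative, not the repository's actual build conventions.

```cmake
# Hypothetical CMakeLists.txt fragment (illustrative names/paths only).
# Assumes GTest and the CUDA toolkit have been found and CUDA is
# enabled as a project language.
add_executable(test_offset_caching test_offset_caching.cu)
target_link_libraries(test_offset_caching
                      PRIVATE GTest::gtest_main CUDA::cudart)
include(GoogleTest)
gtest_discover_tests(test_offset_caching)
```

Linking against GTest::gtest_main supplies main() automatically, so the in-file main() shown above would then be unnecessary.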

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>