Conversation

Force-pushed from d241123 to 2736535 (Compare)
Greptile Summary

This PR caches `offsets` and `first_dims` in shared memory for the graph-safe kernel, and the memory ordering is correct.

Confidence Score: 5/5. Safe to merge; the barrier logic and memory ordering are correct, and the only finding is a P2 missed-optimization suggestion. All barrier initialization, arrive, and wait sequences are correct for both the kEnableRHTColQuant=true (2-arrival tma_barrier) and kEnableRHTColQuant=false (cpasync_barrier) paths. The __syncthreads after barrier init, the cp_async_wait + threadfence_block release pattern, and the constexpr-guarded wait selection are all sound. The sole inline comment is a P2 optimization opportunity (one per-tile gmem read of offsets[group_idx] that could use offsets_smem), not a correctness defect.

Important Files Changed: graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu — verify that the per-tile offsets[group_idx] access at line 869 is intentionally left as a gmem read or should be switched to offsets_smem.
Sequence Diagram

```mermaid
sequenceDiagram
    participant DMA as DMA Warp
    participant ColQ as Col Quant Warp
    participant RowQ as Row Quant Warp
    participant TMAB as TMA B Hardware
    Note over DMA: cp.async offsets[] → smem_offsets<br/>cp.async first_dims[] → smem_first_dims
    DMA->>DMA: cp_async_fence() + cp_async_wait<0>() + __threadfence_block()
    alt kEnableRHTColQuant=true
        DMA->>DMA: mbarrier_arrive(tma_barrier[0]) ①
        DMA->>TMAB: launch TMA B load
        TMAB-->>DMA: (async)
        TMAB->>ColQ: mbarrier_arrive(tma_barrier[0]) ②
        Note over ColQ: wait_barrier(tma_barrier[0], phase=0)<br/>fires after both ① and ②
        Note over RowQ: wait_barrier(tma_barrier[0], phase=0)<br/>fires after both ① and ②
        ColQ->>ColQ: read offsets_smem / first_dims_smem
        RowQ->>RowQ: read offsets_smem / first_dims_smem
    else kEnableRHTColQuant=false
        DMA->>DMA: mbarrier_arrive(cpasync_barrier[0]) ①
        DMA->>TMAB: launch TMA B load
        TMAB->>TMAB: mbarrier_arrive(tma_barrier[0]) - no waiter
        Note over RowQ: wait_barrier(cpasync_barrier[0], phase=0)<br/>fires after ①
        RowQ->>RowQ: read offsets_smem / first_dims_smem
    end
```
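The two-arrival pattern in the diagram can be modeled with a small counting barrier in plain Python. This is a toy model of the mbarrier semantics only: `CountingBarrier` and the thread names are illustrative, not kernel code. Both "warps" block until the DMA warp's arrival (①) and the TMA hardware's arrival (②) have both been signaled.

```python
import threading

class CountingBarrier:
    """Toy model of a CUDA mbarrier with an expected arrival count.

    Waiters block until `expected` arrivals have occurred, mirroring the
    2-arrival tma_barrier[0]: one arrive from the DMA warp (smem data
    ready) and one from the TMA hardware (B tile load complete).
    """
    def __init__(self, expected):
        self.expected = expected
        self.count = 0
        self.cond = threading.Condition()

    def arrive(self):
        with self.cond:
            self.count += 1
            if self.count >= self.expected:
                self.cond.notify_all()

    def wait(self):
        with self.cond:
            self.cond.wait_for(lambda: self.count >= self.expected)

events = []
bar = CountingBarrier(expected=2)

def waiter(name):
    # Models an epilogue warp: only read smem once the barrier fires.
    bar.wait()
    events.append(f"{name} reads offsets_smem")

threads = [threading.Thread(target=waiter, args=(n,))
           for n in ("col_quant", "row_quant")]
for t in threads:
    t.start()

bar.arrive()  # ① DMA warp: cp.async data visible in smem
bar.arrive()  # ② TMA hardware: B tile load complete
for t in threads:
    t.join()
assert len(events) == 2  # both waiters proceeded only after both arrivals
```

The kEnableRHTColQuant=false path corresponds to `expected=1` on a separate barrier, since only the DMA warp's arrival gates the row-quant warp.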
Reviews (5). Last reviewed commit: "Add copyright headers to nvfp4 benchmark..."
```python
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
import torch
import torch.cuda.nvtx as nvtx

N = 7168
num_experts = 64


def make_quantizer():
    q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
    q.optimize_for_gemm = True
    return q


def bench(fn, label, iters=100):
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    nvtx.range_push(label)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    nvtx.range_pop()
    torch.cuda.synchronize()
    print(f"{label}: {start.elapsed_time(end) / iters * 1000:.1f} us")


for M in [16384, 65536, 131072]:
    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")

    # 1. graph-safe + equal splits -> O(1) division (SAME_BOTH_DIMS)
    equal_splits = [M // num_experts] * num_experts
    equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda")
    q1 = make_quantizer()
    bench(
        lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), f"[M={M}] graph_safe_equal_O1"
    )

    # 2. graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM)
    base = M // num_experts
    unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)]
    unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda")
    q2 = make_quantizer()
    bench(
        lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor),
        f"[M={M}] graph_safe_unequal_binary_search",
    )

    # 3. non-graph-safe + linear scan (GetGroupIdx)
    q_list = [
        NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
        for _ in range(num_experts)
    ]
    bench(
        lambda: tex.split_quantize(x, equal_splits, q_list), f"[M={M}] non_graph_safe_linear_scan"
    )

    print()
```
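The three lookup strategies named in the benchmark labels (SAME_BOTH_DIMS O(1) division, VARYING_FIRST_DIM binary search, and the GetGroupIdx linear scan) can be sketched in pure Python. The function names below are illustrative, not the kernel's actual symbols:

```python
import bisect

def find_group_equal(row, rows_per_group):
    # SAME_BOTH_DIMS: all groups share one first dim -> O(1) division.
    return row // rows_per_group

def find_group_binary_search(row, offsets):
    # VARYING_FIRST_DIM: offsets is the exclusive prefix sum of split sizes;
    # binary search finds the group whose [offsets[i], offsets[i+1]) holds row.
    return bisect.bisect_right(offsets, row) - 1

def find_group_linear_scan(row, offsets):
    # GetGroupIdx-style linear scan over the same prefix sums.
    g = 0
    while offsets[g + 1] <= row:
        g += 1
    return g

splits = [3, 5, 2]        # unequal splits, like the benchmark's case 2
offsets = [0, 3, 8, 10]   # exclusive prefix sums of splits
assert find_group_binary_search(4, offsets) == 1   # row 4 lies in [3, 8)
assert find_group_linear_scan(9, offsets) == 2     # row 9 lies in [8, 10)
assert find_group_equal(7, rows_per_group=4) == 1  # 7 // 4 == 1
```

All three return the same group index; they differ only in cost per lookup, which is what the three benchmark labels measure.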
Module-level GPU code will execute on pytest import
All five new scripts (bench_structural.py, bench_sweep_swizzle.py, bench_search.py, bench_graph_safe_swizzle.py, ncu_test.py) contain GPU kernel launches at module scope. When pytest discovers files in tests/pytorch/nvfp4/, it imports each one to collect tests; the imports execute the benchmarks immediately — potentially hanging or crashing CI on machines without the required GPU or package.
Wrap the benchmark body in an `if __name__ == "__main__":` guard in all five files, e.g.:

```python
if __name__ == "__main__":
    for M in [16384, 65536, 131072]:
        ...
```

```cuda
// the 2nd arrival, firing the barrier. Epilogue warps wait on tma_barrier[0] before reading
// offsets_smem/first_dims_smem.
// For kEnableRHTColQuant=false: cpasync_barrier[0] is used instead.
constexpr int kWarpSize = 32;
```
Use `NumThreadsPerWarp` from the CUTLASS header; there is no need for a redundant definition.
```cuda
// No TMA B in this path. Block until all cp.async ops issued above are complete, then
// signal cpasync_barrier[0] so the row quant warp can safely read offsets_smem.
asm volatile("cp.async.commit_group;\n" ::);
asm volatile("cp.async.wait_all;\n" ::);
```
Use `cp_async_fence()` and `cp_async_wait()` instead.
Commits:

- "…fe kernel": Use LDGSTS (cp.async) in the DMA warp to load offsets and first_dims arrays from global memory into shared memory, replacing direct global reads in the epilogue/row-quant warps. Adds cpasync_barrier for the non-RHTColQuant path and smem_offsets/smem_first_dims fields to SharedStorage. Includes offset-caching unit tests and swizzle benchmark. Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
- Add bench_search.py, bench_structural.py, bench_sweep_swizzle.py, ncu_test.py. Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
- Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
- pre-commit autofixes; for more information, see https://pre-commit.ci. Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
- "… kernel": Replace constexpr kWarpSize=32 with cutlass::NumThreadsPerWarp; replace asm volatile cp.async.commit_group/wait_all with cute::cp_async_fence()/cp_async_wait<0>(). Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Force-pushed from 7f75817 to 1f8dcd0 (Compare)
```cpp
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// Tests that caching offsets/first_dims from gmem into smem via cp.async
// produces identical results to reading directly from gmem in
// get_current_tensor_id().

#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <cstdint>
#include <vector>
```
Test file never runs — missing build system integration
test_offset_caching.cu uses GTest (TEST(...) fixtures) but has no main() function and is not referenced in any CMakeLists.txt (there is no CMakeLists.txt in transformer_engine/common/hadamard_transform/, and a project-wide search finds no reference to this file). As a result this test is never compiled or executed in CI, so the barrier correctness it is meant to verify is untested.
To wire it in, either add a CMakeLists.txt entry that links against gtest_main, or add a main() directly:

```cpp
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
```

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
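For the CMakeLists.txt route mentioned above, a minimal entry might look like the sketch below. The target name, source path, and use of `gtest_discover_tests` are assumptions about the build layout, not verified against this repo:

```cmake
# Hypothetical CMakeLists.txt snippet; target and source names are assumptions.
find_package(CUDAToolkit REQUIRED)
find_package(GTest REQUIRED)

add_executable(test_offset_caching test_offset_caching.cu)
target_link_libraries(test_offset_caching PRIVATE GTest::gtest_main CUDA::cudart)

include(GoogleTest)
gtest_discover_tests(test_offset_caching)
```

Linking `GTest::gtest_main` supplies the missing main(), and `gtest_discover_tests` registers each TEST(...) with CTest so CI actually runs it.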
Description
Cache offsets and first_dims in shared memory for the graph-safe kernel; this reduces the number of bytes read from global memory.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: