Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114
Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114vthumbe1503 wants to merge 76 commits into
Conversation
Route grouped Float8CurrentScalingQuantizer through the existing grouped quantize entry point, prepare per-group current-scaling metadata with existing amax/scale helpers, and add focused tests plus a GB200 bandwidth benchmark. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_5507e814ee50f9ff304a4ce708d19768 Orchestra-Run: run_516e1e26891f4ce7d4cde07147c10862
Use wider vectorized grouped FP8 cast-transpose tiles and vectorized masked stores for rowwise and columnwise outputs. Capture all benchmark modes in a single post-warmup profiler range. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3d6e33eab11e293d72eb4394bad76a81 Orchestra-Run: run_a6e2c31d5fdf850594f71438e53148da
Route non-MXFP8 grouped-linear bias backward through group_quantize plus grouped dbias while keeping MXFP8 bgrad_group_quantize fusion intact. Add focused zero-row grouped FP8 coverage and a current-scaling GroupedLinear bias-backward regression. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_ab566800d87047635cd27f9e64661abe Orchestra-Run: run_5f9bfef17ccd854232c54d56268ef9e8
Use packed FP8 conversion and reduce columnwise transpose staging register and synchronization overhead in group_cast_fp8_kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_7a830e018ceac8de0018280bd0740a54 Orchestra-Run: run_d2f1df4ffc2265d9cfa5ed01028ee476
Match the grouped FP8 conversion helper's element-count template parameter to Vec's uint32_t parameter so rowwise, columnwise, and activation instantiations can build. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_30c4b6ddb896e5ea3ca5b54731d2c819 Orchestra-Run: run_e95cdbb445943304622b95736f0eca49
Use cached grouped offsets to avoid launching FP8 quantization over unused overallocated rows, permit larger grouped backing buffers when split metadata is present, and tighten full-tile vector paths in the grouped FP8 cast kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_c5db93823dc101838cb1323e283cd6e9 Orchestra-Run: run_063e2e4c724e132612aa5597d6765c9b
Use the FP8 grouped output logical shape when computing the tensor-scaling launch grid so overallocated buffers with active metadata avoid empty tail-row launches while preserving the allocated-shape fallback. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_b4abb47c990404d73142342a19996a3f Orchestra-Run: run_8f09e7b9d7af9754ef505f2e2ce3cf90
Use larger grouped FP8 tiles with 8-warp CTAs and 16-row columnwise store fragments. Treat uniform overallocated FP8 grouped outputs as same-shape wrappers during output reuse so the timed path avoids varying-shape metadata overlaunch. Add overallocated current-scaling coverage for all grouped FP8 direction modes. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3f98ac9c5b82192ec289d8d2a9816c7f Orchestra-Run: run_83f3b99cc950024cf06ee836337fbf72
Stage columnwise transpose fragments through shared-memory vectors with smaller columnwise row tiles to reduce register pressure and barrier overhead while preserving the larger rowwise-only store path. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_495cc57eef84749103aded403a508d99 Orchestra-Run: run_53e038e90f83186bc6c12cb722c986b5
Add fast grouped FP8 rowwise and full-tile columnwise paths for uniform active groups while preserving the general fallback for varying grouped metadata. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_4c33e88776c8a7148e9da5cc2bae84ea Orchestra-Run: run_2caaff219394eb5d59b7be38ab2bf346
Add a same-shape bidirectional full-tile kernel with wider input vectors and rowwise stores while preserving the existing rowwise-only, columnwise-only, and fallback grouped paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_87cec01d94f053b53e3c79377ad379ab Orchestra-Run: run_ed48db00a730a4bf56530d551ecd350e
Route same-shape rowwise+columnwise grouped FP8 tensor-scaling quantization through the compact full-tile transpose schedule instead of the wide dynamic-shared-memory variant, preserving the existing single-direction and fallback paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_fdddd228a620039c024b4ecf43f3ab42 Orchestra-Run: run_30a2753eea9c893cb0fadb8233da8ce6
Hint the rowwise stores in the full-tile rowwise+columnwise grouped FP8 path as streaming global stores to reduce cache/writeback pressure without changing single-direction launch geometry. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_bf82020032e68276f4e47c65f62d97ae Orchestra-Run: run_754ea4c864f329c6f2003b413b723c43
Add graph-safe grouped FP8 tensor-scaling metadata, support varying last dimensions, preserve same-shape fast paths, adjust grouped FP8 columnwise allocation by architecture, and expand benchmark/test coverage for the reviewed shape cases. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d104e74844fbc3d3b1a98a8d96d76037 Orchestra-Run: run_1314e997c61ffb92ff7120b0b26f0318
Map varying-last columnwise tiles per group to avoid tile-alignment device errors, expand nonaligned boundary coverage, and restore same-shape benchmark baseline criteria. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_14e0e7973300d26f69550bc0aee21acc Orchestra-Run: run_2f42b8ba138ed8b2b4d9dc90b92caf85
Add grouped FP8 benchmark support for baseline-ref same-session reports and update the benchmark request to enforce same-shape baseline regression checks alongside the per-mode throughput thresholds. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d0cada957a4aafdce9d52be86520e182 Orchestra-Run: run_4da74e9bdb4f4a4c72304a385692b6c9
Update the grouped FP8 benchmark driver so same-session baseline checks out and builds the baseline ref into an isolated PyTorch install, verifies the baseline subprocess loads those shared objects, and preserves the required same-shape baseline comparisons. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_4fd88b172872f547f2f2d0053dce73d1 Orchestra-Run: run_6a44ee0467ffff47d4b278de6127354d
Preserve grouped delayed-FP8 amax metadata and keep unsupported FP8 tensor-scaling quantizers out of the grouped GEMM path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2aa8e6bf11ae356f4b34d4540b508031 Orchestra-Run: run_302681098d7f4e05b0ad96450f2d9826
Set NVTE_GROUPED_LINEAR_SINGLE_PARAM inside the targeted state-dict tests so they exercise the gated single grouped parameter path without relying on external environment setup. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_261900f987bdc9397965019983a77c41 Orchestra-Run: run_c6624e34717cbe121b3e0edcf490e3d3
Add a segmented flat rowwise kernel for varying-first grouped FP8 tensor-scaling outputs while preserving the existing same-shape fast path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c1b7020b27290318848ef6ac9048dd5f Orchestra-Run: run_5c257b8a5d2e7e4aa95e67aa16436166
Omit the last_dims keyword when absent so the same-session baseline can run against the base extension, and refresh the benchmark request to include direct varying-last current-scaling coverage. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c20a3c94fdc798e741a469bd7bb9c4df Orchestra-Run: run_457448e6cba80fc63ac72b3db71c5fd0
Dispatch varying-first tensor-scaling work per group to reduce inactive-tail CTAs and offset lookup overhead while preserving same-shape fast paths and graph-safe device metadata handling. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d84e1fefef8641e558df064452f4689b Orchestra-Run: run_a361ca2f93fcec53ddd60dd99f4639e5
Add a no-tail rowwise flat kernel for aligned varying-first grouped FP8 tensor-scaling quantization and keep same-shape and varying-last dispatch isolated. Tighten benchmark profiler timing so post-warmup measured ranges exclude profiler start overhead. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2e478be1fb38195f36d25c51320dc01f Orchestra-Run: run_9a133a75fa3d98dc3b1a63b0ff4d84af
Write grouped FP8 benchmark reports to a sidecar path by default and label script reports as benchmark_raw_report/v1 so regular 100-iteration measurements are fetched instead of the wrapper command report. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_27770b2e1d490b1a3053244d4b4ce248 Orchestra-Run: run_214052d0c1316e231443d645183a2675
Write the grouped FP8 benchmark JSON once and mirror the completed sidecar to ORCHESTRA_BENCHMARK_RAW_REPORT when running under Orchestra so the benchmark fetch path can parse the emitted measurements. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_b2e2747371204088c8e3f7cf10263164 Orchestra-Run: run_1d4ea38266807c8acb59143ee74ba241
Allow the grouped FP8 benchmark to use ORCHESTRA_BENCHMARK_RAW_REPORT as its primary output so the benchmark wrapper can fetch canonical measurements directly. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_10fdcfef6b70de4676b7843e4bbfac31 Orchestra-Run: run_4ce57df9e86d6d03a26f7aa95ac252cc
Write canonical grouped FP8 benchmark measurements to ORCHESTRA_BENCHMARK_RAW_REPORT in a small schema-shaped payload so the benchmark wrapper can materialize per-mode threshold evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_3e862eebd585c74f2a58497fedea3511 Orchestra-Run: run_3770ab3dbbf51329d0839b3d10a91b5c
Write candidate_results and nonempty measurements into the Orchestra raw report path, and fail fast if the benchmark cannot produce threshold-ready evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_aa587a7b0d35aa9c2b715ec1b7c8bec3 Orchestra-Run: run_b42870e5d5e142a6cbf53bb5a3cafc2e
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…TransformerEngine into current_scaling_group_quant
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci pytorch |
Greptile SummaryThis PR adds two major features: (1) a new FP8 current-scaling grouped-quantization path for
Confidence Score: 5/5Safe to merge; the end-to-end FP8 current-scaling grouped-quantization pipeline and the varying-last-dim dispatch are logically sound and correctly guarded against unsupported combinations. The core additions — flat amax scan, scale derivation, 2D-grid group-quantize dispatch for VARYING_LAST_DIM, and the noop-flag plumbing — are all correctly wired together. Previously reported bugs (positional arg shift in grouped_mlp.py, noop_flag forwarding, error message accuracy) have been addressed. The remaining findings are documentation and minor interface-clarity issues that do not affect correctness. group_amax_fp8.cuh (unused-parameter API surface) and grouped_tensor_storage.py (late failure for float8_current_scaling + VARYING_BOTH via Python path). Important Files Changed
Sequence DiagramsequenceDiagram
participant PY as Python caller
participant C as group_quantize (C++)
participant CS as Float8CurrentScaling<br/>Quantizer::create_grouped_tensor
participant AM as nvte_group_compute_amax<br/>_with_config
participant SC as nvte_group_compute_scale<br/>_from_amax
participant QK as nvte_group_quantize<br/>(group_quantize_fp8.cuh)
PY->>C: "group_quantize(tensor, quantizer,<br/>num_tensors, first_dims, last_dims, noop_flag)"
C->>CS: create_grouped_tensor(first_dims, last_dims, ...)
CS-->>C: "grouped_output_cpp + grouped_output_py<br/>(amax/scale/scale_inv/data buffers)"
C->>AM: "nvte_group_compute_amax_with_config<br/>(input, output, config, stream)"
Note over AM: zero amax buffer,<br/>flat-scan over elements,<br/>atomicMax per tensor
AM-->>C: amax[0..num_tensors-1] filled
opt with_amax_reduction
C->>C: NCCL AllReduce MAX on amax buffer
end
C->>SC: "nvte_group_compute_scale_from_amax<br/>(output, config, stream)"
Note over SC: scale[i] = fp8_max / amax[i]<br/>scale_inv[i] = 1 / scale[i]
SC-->>C: scale / scale_inv filled
C->>QK: "nvte_group_quantize(input, output,<br/>quant_config, stream)"
Note over QK: dispatch on shape_rep<br/>(SAME/VARYING_FIRST/VARYING_LAST)<br/>read scale_ptr[tensor_id] per element
QK-->>C: FP8 quantized data written
C-->>PY: grouped_output_py
Reviews (7): Last reviewed commit: "fix comment and add nvte check" | Re-trigger Greptile |
… specific Signed-off-by: Varun Thumbe <vthumbe@vthumbe-mlt.client.nvidia.com>
|
/te-ci |
|
Pipeline: 54747206 |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…scale from amax Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
…TransformerEngine into current_scaling_group_quant
Removed duplicate brief comment about scaled prefix-sum offsets. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
|
/te-ci |
Signed-off-by: Varun Thumbe <vthumbe@vthumbe-mlt.client.nvidia.com>
|
/te-ci |
Oleg-Goncharov
left a comment
There was a problem hiding this comment.
Hi @vthumbe1503, could you please also add C++ unit tests covering various shapes, similar to grouped MXFP8?
Description
Performance from Benchmarking Script for Current Scaling Group Quantize.
NOTE:
GB200

H100



Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Core Changes
common.cuhhas been seperated out into two files --grouped_layout.cuhandgrouped_tma.cuh. The idea is that,grouped_layout.cuhhas generic utilities commonly used in grouped tensor kernels. Andgrouped_tma.cuhhas the arch specific TMA changes. This is done so that current scaling can still be non-arch specific and it can use just thegrouped_layout.cuhsince we didnt need TMA for current scaling.Other Common Changes
Pytorch Group Quantize API Changes
Benchmarking Scripts/Tests added for Current Scaling for Varying First/Last Dims and to handle overallocation of grouped tensors. Varying All Dims is not supported yet.
Checklist: