NVFP4: cache GEMM-swizzled weight scale factors across micro-batches #3093
base: main
Changes from all commits
0ee2ede
f04d800
e059449
88a8e84
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
| """Tests for the cached-weight scale-swizzle optimization. | ||
|
|
||
| For block-scaled NVFP4 a weight participates in two GEMMs per step: | ||
|
|
||
| * fprop: ``y = x @ Wt`` -> consumes the weight's **rowwise** scale factors | ||
| * dgrad: ``dx = dY @ W`` -> consumes the weight's **columnwise** scale factors | ||
|
|
||
| cuBLAS/CUTLASS needs those scale factors in a GEMM-"swizzled" layout. Without | ||
| ``optimize_for_gemm`` on the *weight* quantizer that swizzle is recomputed | ||
| lazily inside every GEMM and discarded, so with ``N`` micro-batches the weight | ||
| scale swizzle runs ``2*N`` times per step even though the weight is quantized | ||
| once. When the quantized weight is cached across micro-batches | ||
| (``is_first_microbatch`` is not ``None``) and FSDP is not in use, the module | ||
| sets ``weight_quantizer.optimize_for_gemm = True`` so the swizzle is done once | ||
| at quantize time, persisted on the cached workspace | ||
| (``_with_gemm_swizzled_scales = True``), and reused by every GEMM -> ``2`` | ||
| swizzles per step instead of ``2*N``. | ||
|
|
||
| These tests verify that: | ||
|
|
||
| 1. The optimization is **numerically a no-op**: swizzling is a pure layout | ||
| permutation of the scale factors, so the cached (eager-swizzle) path must | ||
| produce the same fprop output and dgrad as the un-cached (lazy-swizzle) | ||
| baseline, for every distinct micro-batch. | ||
| 2. The ``_with_gemm_swizzled_scales`` flag is actually set and persisted on the | ||
| cached weight workspace. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| import transformer_engine.pytorch as te | ||
| from transformer_engine.common.recipe import NVFP4BlockScaling | ||
|
|
||
|
|
||
| recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) | ||
|
|
||
| pytestmark = pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) | ||
|
|
||
|
|
||
| def _make_module(kind, in_features, out_features, device, num_gemms=1): | ||
| common = dict(bias=True, params_dtype=torch.bfloat16) | ||
| if kind == "Linear": | ||
| return te.Linear(in_features, out_features, **common).to(device) | ||
| if kind == "LayerNormLinear": | ||
| return te.LayerNormLinear(in_features, out_features, **common).to(device) | ||
| if kind == "LayerNormMLP": | ||
| # fc1 (in->ffn) and fc2 (ffn->in) each cache a weight, so two workspaces. | ||
| return te.LayerNormMLP(in_features, out_features, **common).to(device) | ||
| if kind == "GroupedLinear": | ||
| return te.GroupedLinear(num_gemms, in_features, out_features, **common).to(device) | ||
| raise ValueError(f"unknown module kind {kind}") | ||
|
|
||
|
|
||
| def _expected_num_workspaces(kind, num_gemms): | ||
| """Number of cached weight workspaces a module populates on the cached path.""" | ||
| if kind == "GroupedLinear": | ||
| return num_gemms # one per expert | ||
| if kind == "LayerNormMLP": | ||
| return 2 # fc1 + fc2 | ||
| return 1 | ||
|
|
||
|
|
||
| def _clone_params(src, dst): | ||
| """Copy src's parameters into dst so both modules start identical.""" | ||
| with torch.no_grad(): | ||
| dst_params = dict(dst.named_parameters()) | ||
| for name, param in src.named_parameters(): | ||
| dst_params[name].copy_(param) | ||
|
|
||
|
|
||
| def _step(module, x, is_first, recipe, m_splits=None): | ||
| x = x.detach().clone().requires_grad_(True) | ||
| module.zero_grad(set_to_none=True) # per-micro-batch grads (no accumulation) | ||
| with te.autocast(enabled=True, recipe=recipe): | ||
| # Passing m_splits is the only call-site difference for GroupedLinear; the | ||
| # rest of the cached-weight numerics check is identical across modules. | ||
| if m_splits is None: | ||
| out = module(x, is_first_microbatch=is_first) | ||
| else: | ||
| out = module(x, m_splits, is_first_microbatch=is_first) | ||
| out.sum().backward() | ||
| return out.detach().float(), x.grad.detach().float() | ||
|
|
||
|
|
||
| _MODULE_KINDS = ["Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear"] | ||
|
|
||
|
|
||
| def _grouped_m_splits(kind, batch, num_gemms): | ||
| """m_splits for GroupedLinear (even token split across experts), else None.""" | ||
| if kind != "GroupedLinear": | ||
| return None | ||
| return [batch // num_gemms] * num_gemms | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("kind", _MODULE_KINDS) | ||
| @pytest.mark.parametrize("microbatches", [1, 4]) | ||
| @pytest.mark.parametrize("shape", [(1024, 1024), (2048, 512)], ids=["1024x1024", "2048x512"]) | ||
| def test_weight_swizzle_cache_numerics(kind, microbatches, shape): | ||
| """Cached eager-swizzle path == lazy-swizzle baseline (fprop + dgrad). | ||
|
|
||
| Shared across all module kinds: GroupedLinear only differs by passing | ||
| m_splits (handled in _step), LayerNormMLP only by caching two weights. | ||
| """ | ||
| torch.manual_seed(1234) | ||
| device = "cuda" | ||
| in_features, out_features = shape | ||
| batch = 512 | ||
| num_gemms = 2 if kind == "GroupedLinear" else 1 | ||
| m_splits = _grouped_m_splits(kind, batch, num_gemms) | ||
|
|
||
| # Stochastic rounding is the only run-to-run nondeterminism source (RHT uses | ||
| # a fixed sign mask) and it is applied to the bwd grad regardless of this | ||
| # optimization, so disable it to make eager-vs-lazy weight swizzle | ||
| # bit-comparable. The swizzle is a pure layout transform, so with SR off the | ||
| # two paths must match tightly. | ||
| recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) | ||
|
|
||
| # ref: always lazy-swizzle (is_first_microbatch=None => no weight cache => | ||
| # optimize_for_gemm stays False). opt: cached eager-swizzle path. Identical | ||
| # weights so per-micro-batch outputs are directly comparable. | ||
| ref = _make_module(kind, in_features, out_features, device, num_gemms) | ||
| opt = _make_module(kind, in_features, out_features, device, num_gemms) | ||
| _clone_params(ref, opt) | ||
|
|
||
| # Distinct inputs per micro-batch (mirrors gradient accumulation: different | ||
| # data each micro-batch, same weight). | ||
| inputs = [ | ||
| torch.randn(batch, in_features, dtype=torch.bfloat16, device=device) | ||
| for _ in range(microbatches) | ||
| ] | ||
|
|
||
| atol, rtol = 1e-3, 1e-3 | ||
| for mb in range(microbatches): | ||
| ref_out, ref_dgrad = _step(ref, inputs[mb], None, recipe, m_splits) | ||
| opt_out, opt_dgrad = _step(opt, inputs[mb], mb == 0, recipe, m_splits) | ||
| torch.testing.assert_close( | ||
| opt_out, ref_out, atol=atol, rtol=rtol, msg=f"fprop mismatch at mb {mb}" | ||
| ) | ||
| torch.testing.assert_close( | ||
| opt_dgrad, ref_dgrad, atol=atol, rtol=rtol, msg=f"dgrad mismatch at mb {mb}" | ||
| ) | ||
|
|
||
| # The swizzled flag must be set & persisted on every cached weight workspace | ||
| # (one per expert for GroupedLinear, fc1+fc2 for LayerNormMLP, else one). | ||
| workspaces = opt._fp8_workspaces | ||
| assert len(workspaces) == _expected_num_workspaces( | ||
| kind, num_gemms | ||
| ), f"unexpected cached weight workspace count for {kind}: {len(workspaces)}" | ||
| for name, ws in workspaces.items(): | ||
| assert getattr(ws, "_with_gemm_swizzled_scales", False) is True, ( | ||
| f"cached weight workspace {name!r} scales were not pre-swizzled " | ||
| "(optimize_for_gemm not applied)" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("kind", _MODULE_KINDS) | ||
| def test_lazy_path_not_swizzled(kind): | ||
| """Without weight caching (is_first_microbatch=None) no workspace is created | ||
| and the optimization stays off — guards against accidentally always-on.""" | ||
| torch.manual_seed(0) | ||
| device = "cuda" | ||
| batch = 512 | ||
| num_gemms = 2 if kind == "GroupedLinear" else 1 | ||
| m_splits = _grouped_m_splits(kind, batch, num_gemms) | ||
| recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) | ||
| module = _make_module(kind, 1024, 1024, device, num_gemms) | ||
| x = torch.randn(batch, 1024, dtype=torch.bfloat16, device=device, requires_grad=True) | ||
| with te.autocast(enabled=True, recipe=recipe): | ||
| if m_splits is None: | ||
| out = module(x, is_first_microbatch=None) | ||
| else: | ||
| out = module(x, m_splits, is_first_microbatch=None) | ||
| out.sum().backward() | ||
| assert ( | ||
| not module._fp8_workspaces | ||
| ), "lazy path (is_first_microbatch=None) must not populate the weight cache" | ||
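The `2*N` versus `2` swizzle count described in the test module's docstring can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyWeight`, `toy_gemm`, and `count_swizzles` are hypothetical stand-ins written for this note, not Transformer Engine APIs, and the `pre_swizzled` field merely plays the role of `_with_gemm_swizzled_scales`.

```python
class ToyWeight:
    """Hypothetical stand-in for a cached quantized weight."""

    def __init__(self):
        self.swizzle_calls = 0     # how many scale-factor swizzles have run
        self.pre_swizzled = False  # plays the role of _with_gemm_swizzled_scales


def toy_gemm(weight):
    """One GEMM: swizzle the scales lazily unless they were pre-swizzled."""
    if not weight.pre_swizzled:
        weight.swizzle_calls += 1


def count_swizzles(num_microbatches, optimize_for_gemm):
    """Count weight-scale swizzles in one step of num_microbatches micro-batches."""
    weight = ToyWeight()
    if optimize_for_gemm:
        # Eager path: rowwise + columnwise scales are swizzled once at quantize time.
        weight.swizzle_calls += 2
        weight.pre_swizzled = True
    for _ in range(num_microbatches):
        toy_gemm(weight)  # fprop consumes the rowwise scales
        toy_gemm(weight)  # dgrad consumes the columnwise scales
    return weight.swizzle_calls
```

In this model the lazy path pays for one swizzle per GEMM, so the count grows with the number of micro-batches, while the eager path pays a fixed cost of two swizzles at quantize time.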
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1735,6 +1735,16 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||
| else [None] * num_gemms | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Pre-swizzle (and cache) the weight scale factors when the quantized | ||||||||||||||||||||||||||||||||||||||||||||
| # weights are cached across microbatches, so the per-GEMM scale swizzle | ||||||||||||||||||||||||||||||||||||||||||||
| # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses | ||||||||||||||||||||||||||||||||||||||||||||
| # from 2*num_microbatches kernels to 2 per step per expert. | ||||||||||||||||||||||||||||||||||||||||||||
| # No-op for non-swizzled recipes (e.g. per-tensor FP8). | ||||||||||||||||||||||||||||||||||||||||||||
| if cache_weight: | ||||||||||||||||||||||||||||||||||||||||||||
| for weight_quantizer in weight_quantizers: | ||||||||||||||||||||||||||||||||||||||||||||
| if weight_quantizer is not None: | ||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.optimize_for_gemm = True | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines +1738 to +1747

Collaborator: I don't think the comment is relevant. In case of FSDP/FSDP2,

Suggested change

Collaborator: Same applies in other files.

Contributor (Author): Applied to all four module files. |
||||||||||||||||||||||||||||||||||||||||||||
| non_tensor_args = ( | ||||||||||||||||||||||||||||||||||||||||||||
| self.apply_bias, | ||||||||||||||||||||||||||||||||||||||||||||
| is_first_microbatch, | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
`LayerNormMLP` test coverage: `layernorm_mlp.py` is one of four files modified by this PR, yet the test suite parametrizes only over `["Linear", "LayerNormLinear"]` for both `test_weight_swizzle_cache_numerics` and `test_lazy_path_not_swizzled`. The fc1/fc2 two-quantizer path in `LayerNormMLP` is structurally different from the single-quantizer modules: it independently gates `fc1_weight_quantizer.optimize_for_gemm` and `fc2_weight_quantizer.optimize_for_gemm` using separate `cache_name_fc1`/`cache_name_fc2` variables. If either gating expression were wrong (e.g. swapping fc1/fc2 names), existing tests would not catch it.
Added LayerNormMLP coverage (fc1+fc2 two-quantizer path) to both parametrized tests.
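
To illustrate why a per-workspace flag check catches a mis-gated fc1/fc2 quantizer, here is a minimal sketch using stand-in objects. `ToyWorkspace` and `unswizzled_workspaces` are hypothetical helpers, and the cache keys `"fc1_weight"` and `"fc2_weight"` are assumed names for illustration, not the keys Transformer Engine actually uses. Only the `_with_gemm_swizzled_scales` attribute name comes from the PR.

```python
class ToyWorkspace:
    """Hypothetical stand-in for a cached weight workspace."""

    def __init__(self, swizzled):
        self._with_gemm_swizzled_scales = swizzled


def unswizzled_workspaces(workspaces):
    """Return the sorted names of workspaces whose scales were not pre-swizzled."""
    return sorted(
        name
        for name, ws in workspaces.items()
        if getattr(ws, "_with_gemm_swizzled_scales", False) is not True
    )


# Both quantizers gated correctly: every cached workspace carries the flag.
correct = {"fc1_weight": ToyWorkspace(True), "fc2_weight": ToyWorkspace(True)}

# fc2 gating broken (e.g. wrong cache name): only fc1 was pre-swizzled.
mis_gated = {"fc1_weight": ToyWorkspace(True), "fc2_weight": ToyWorkspace(False)}
```

Checking every entry rather than just the first is what makes the two-quantizer case observable: a single-workspace assertion would pass on `mis_gated`, while the per-entry check reports `fc2_weight`.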