[Pytorch][Common] Hybrid quantization #2817
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR introduces hybrid (per-direction) quantization, allowing different quantization formats for rowwise vs. columnwise GEMM directions. This enables, for example, MXFP8 forward + NVFP4 backward without hardcoding every format combination. It also adds a high-precision passthrough
Confidence Score: 4/5
Safe to merge for the claimed-working paths (1-GPU, DP + dist-opt, FSDP2, CPU offload, activation recompute, TP/SP, MoE); the two missing activation_dtype casts in _hybrid_split_quantize are non-blocking style inconsistencies.
The hybrid quantization machinery is well-designed and the claimed integration paths are covered by extensive bitwise tests. Two call sites in GroupedLinear's backward pass hand grad_output and the wgrad input to _hybrid_split_quantize without the explicit activation_dtype cast that the non-hybrid path applies. This is a minor numerical consistency gap, not a correctness failure, but it should be addressed before the hybrid grouped-linear backward is used in precision-sensitive mixed-dtype training.
transformer_engine/pytorch/module/grouped_linear.py: the hybrid backward quantization paths (grad_output split at ~line 1187 and wgrad input split at ~line 1320) are missing the activation_dtype cast present in every other split-quantize call site.
Important Files Changed
Class Diagram
%%{init: {'theme': 'neutral'}}%%
classDiagram
class Quantizer {
+rowwise_usage: bool
+columnwise_usage: bool
+quantize(tensor)
+make_empty(shape)
+update_quantized(src, dst)
}
class HybridQuantizer {
+rowwise_quantizer: Quantizer
+columnwise_quantizer: Quantizer
+quantize_impl(tensor) HybridQuantizedTensor
+supports_only_rowwise_all_gather() bool
+with_amax_reduction: bool
}
class QuantizedTensorStorage {
+fsdp_buffer_fields()
+fsdp_extract_buffers()
+fsdp_assign_gathered()
+prepare_for_saving()
+restore_from_saved()
}
class HybridQuantizedTensorStorage {
+_rowwise_storage: QuantizedTensorStorage
+_columnwise_storage: QuantizedTensorStorage
+_quantizer: HybridQuantizer
+update_usage()
+dequantize()
+view()
}
class HybridQuantizedTensor {
+fsdp_pre_all_gather()
+fsdp_post_all_gather()
+detach() HybridQuantizedTensor
+__torch_dispatch__()
+__reduce_ex__()
}
class IdentityQuantizer {
+dtype: Optional~torch.dtype~
+quantize_impl(tensor) IdentityTensor
}
class Float8TensorStorage {
+fsdp_buffer_fields() direction-aware
+fsdp_assign_gathered() clears _transpose_invalid
+view() null-guards _data
}
class MXFP8TensorStorage {
+fsdp_buffer_fields()
+fsdp_extract_buffers() strips scale padding
+fsdp_assign_gathered() re-pads scales
}
Quantizer <|-- HybridQuantizer
Quantizer <|-- IdentityQuantizer
QuantizedTensorStorage <|-- HybridQuantizedTensorStorage
HybridQuantizedTensorStorage <|-- HybridQuantizedTensor
HybridQuantizer "1" --> "2" Quantizer : owns rowwise + columnwise
HybridQuantizedTensorStorage "1" --> "2" QuantizedTensorStorage : rowwise + columnwise sub
QuantizedTensorStorage <|-- Float8TensorStorage
QuantizedTensorStorage <|-- MXFP8TensorStorage
timmoon10 left a comment
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | ||
| columnwise_result = self.columnwise_quantizer.quantize(tensor) |
Do we handle the case where not all usages are needed? I'd expect something like:
| rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None |
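The suggestion amounts to guarding each sub-quantizer call with its usage flag so that an unused direction is never computed. A minimal sketch of that pattern follows; `ToyQuantizer` and `ToyHybridQuantizer` are hypothetical stand-ins written for illustration and are not Transformer Engine classes.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ToyQuantizer:
    """Hypothetical per-direction quantizer that counts how often it runs."""

    name: str
    calls: int = 0

    def quantize(self, tensor: List[float]) -> Tuple[str, List[int]]:
        self.calls += 1
        # Toy "quantization": round to integers and tag with the format name.
        return (self.name, [round(x) for x in tensor])


@dataclass
class ToyHybridQuantizer:
    """Hypothetical hybrid wrapper that skips directions nobody asked for."""

    rowwise_quantizer: ToyQuantizer
    columnwise_quantizer: ToyQuantizer
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def quantize(
        self, tensor: List[float]
    ) -> Tuple[Optional[Tuple[str, List[int]]], Optional[Tuple[str, List[int]]]]:
        # Each direction is produced only when its usage flag is set.
        rowwise = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
        columnwise = (
            self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
        )
        return rowwise, columnwise


hybrid = ToyHybridQuantizer(
    ToyQuantizer("fp8"), ToyQuantizer("nvfp4"), columnwise_usage=False
)
row, col = hybrid.quantize([0.4, 1.6, -2.2])
print(row, col, hybrid.columnwise_quantizer.calls)
```

With `columnwise_usage=False` the columnwise sub-quantizer is never invoked, which is the saving the reviewer is asking about.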
| requires_grad: bool = False, | ||
| pin_memory: bool = False, | ||
| ) -> HybridQuantizedTensor: | ||
| self.rowwise_quantizer.internal = True |
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
This would not work under FSDP2.
| def factory(role): | ||
| if role == "linear_weight": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_mxfp8_quantizer(), | ||
| ) | ||
| if role == "linear_input": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if role in ("linear_grad_output", "linear_grad_input"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_mxfp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| return None |
This is horrifying. Good test.
| outs = [ | ||
| Float8Tensor.make_like( | ||
| tensor, | ||
| data=split_tensor, | ||
| data_transpose=split_transpose_tensor, | ||
| shape=split_tensor.shape, | ||
| shape=( | ||
| split_tensor.shape | ||
| if split_tensor is not None | ||
| else split_transpose_tensor.shape | ||
| ), | ||
| ) | ||
| for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) | ||
| ] |
When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.
| outs = [ | |
| Float8Tensor.make_like( | |
| tensor, | |
| data=split_tensor, | |
| data_transpose=split_transpose_tensor, | |
| shape=( | |
| split_tensor.shape | |
| if split_tensor is not None | |
| # _transpose has shape [K, M/n] but the shard's nominal shape | |
| # is [M/n, K]. Recover the correct shard shape by reversing | |
| # the last two dims of the transposed piece. | |
| else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0]) | |
| ), | |
| ) | |
| for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) | |
| ] |
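The shape arithmetic in the suggestion can be checked in isolation. The sketch below uses plain tuples rather than tensors; `nominal_shard_shape` is a hypothetical helper name, and it assumes, as the comment in the suggestion states, that the transposed buffer stores the last logical dimension first.

```python
from typing import Optional, Tuple


def nominal_shard_shape(
    data_shape: Optional[Tuple[int, ...]],
    transpose_shape: Optional[Tuple[int, ...]],
) -> Tuple[int, ...]:
    """Return the logical (non-transposed) shape of a shard.

    Prefers the rowwise data shape; otherwise recovers it from the
    transposed piece by moving its leading dimension to the end.
    """
    if data_shape is not None:
        return tuple(data_shape)
    if transpose_shape is None:
        raise ValueError("need at least one of data_shape / transpose_shape")
    # Assumption: the transposed layout is (last_dim, *leading_dims).
    return (*transpose_shape[1:], transpose_shape[0])


# M = 8 rows split into n = 2 shards, K = 3 columns:
# each transposed shard is [K, M/n] = (3, 4), the nominal shard is (4, 3).
print(nominal_shard_shape(None, (3, 4)))
print(nominal_shard_shape((4, 3), (3, 4)))
```

Using the transposed shape directly would report `(3, 4)` for the columnwise-only case, which is the mismatch described in the comment above.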
| # DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories | ||
| # (lambdas, inner functions referencing captured state) are not picklable, | ||
| # so the qfactory must live at module scope. See | ||
| # ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``. |
This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?
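For context on the pickling constraint the code comment describes: standard `pickle` serializes functions by reference to their qualified name, so lambdas and nested functions cannot be looked up again and are rejected. A minimal standard-library sketch, with hypothetical factory names that stand in for a qfactory:

```python
import pickle


def module_scope_factory(role):
    """Defined at module scope, so pickle can refer to it by qualified name."""
    return f"quantizer-for-{role}"


def make_closure_factory(fmt):
    """Returns an inner function that captures `fmt` (a closure)."""

    def factory(role):
        return f"{fmt}-quantizer-for-{role}"

    return factory


def can_pickle(obj) -> bool:
    """Report whether `pickle.dumps` accepts the object."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


closure_ok = can_pickle(make_closure_factory("fp8"))
lambda_ok = can_pickle(lambda role: None)
# Whether this succeeds depends on the defining module being importable.
module_ok = can_pickle(module_scope_factory)
print(closure_ok, lambda_ok, module_ok)
```

The closure and the lambda are both rejected, which is why the comment asks for the qfactory to live at module scope.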
| for param in model.parameters(): | ||
| state = optimizer.state[param] | ||
| assert state["exp_avg"].dtype == torch.float32 | ||
| assert state["exp_avg_sq"].dtype == torch.float32 | ||
| if "master_param" in state: | ||
| assert state["master_param"].dtype == torch.float32 | ||
|
|
||
| assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" |
That's not a very strict test, is there a way for us to do some numerical correctness comparisons?
Enabled a check for monotonic loss decrease (still mostly a sanity check), and also enabled a bitwise hybrid vs. vanilla recipe comparison; see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.
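To make the difference between the two checks concrete, here is a small standalone sketch. It is not the test code from this PR: `is_strictly_decreasing` and `bitwise_equal` are hypothetical helpers, and treating "monotonic" as strictly decreasing is an assumption.

```python
import struct
from typing import Sequence


def is_strictly_decreasing(losses: Sequence[float]) -> bool:
    """True if every loss is lower than the one before it."""
    return all(later < earlier for earlier, later in zip(losses, losses[1:]))


def bitwise_equal(a: Sequence[float], b: Sequence[float]) -> bool:
    """Compare two float sequences by their IEEE-754 double bit patterns."""
    if len(a) != len(b):
        return False
    return struct.pack(f"{len(a)}d", *a) == struct.pack(f"{len(b)}d", *b)


losses = [3.0, 3.1, 2.0]
print(losses[-1] < losses[0])          # endpoint check passes
print(is_strictly_decreasing(losses))  # monotonic check does not
print(bitwise_equal([0.0], [-0.0]))    # equal by value, different bits
```

The endpoint comparison accepts a run that went up in the middle, the monotonic check rejects it, and the bitwise comparison is stricter still because it distinguishes values that compare equal numerically.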
| """Default NVFP4Quantizer: no RHT, no stochastic rounding, no 2D | ||
| scaling — matches upstream ``run_numerics.py::nvfp4_vanilla()`` which | ||
| strips the recipe to bare 1D block scaling for distributed TP | ||
| fairness. Same-format hybrid NVFP4 has no E5M2 variant (NVFP4 is a | ||
| single format family — E2M1 only), so grad roles reuse the same | ||
| NVFP4 quantizer.""" |
Why don't we want to check the full recipe here?
Switched to the full recipe except 1D for weights; will be enabled after #3027 merges.
| is_linear = role is not None and role.module_type in ("linear", "grouped_linear") | ||
| if is_linear and role.tensor_type in ("input", "weight", "output"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_nvfp4_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if is_linear and role.tensor_type in ("grad_output", "grad_input"): | ||
| return _make_nvfp4_quantizer() |
As written those lines are not needed at all. They would be needed if you did the full recipe.
Switched to the full recipe
| # quantization (rowwise and columnwise quantizers run independently, so | ||
| # their outputs may differ by ~1 ULP from a single fused-quantize path | ||
| # in edge cases). |
That does not sound like a good thing if it actually happens in practice. Quantization alone should not be affected by whether you do both directions at the same time or one at a time: the input and the algorithm are the same in both cases. Fusion with the activations could maybe give slightly different results, but I would still like to get an explanation of why that would be.
You are right. If the algorithm is the same, we do get identical results; the new bitwise linear_vs_vanilla test confirms this. The only place where two-pass and fused quantization differ is NVFP4 with RHT + SR: this activates a separate columnwise RNG (need_separate_columnwise_rng), so the RNG stream is consumed differently. See the comment in _backward_not_bitwise_comparable(). Removed the ~1 ULP comment.
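A toy illustration of why only the stochastic-rounding case can diverge. This is not the NVFP4 kernel logic: the rounding functions and the seeds are assumptions chosen for the sketch, and it only models how random draws are assigned to each direction.

```python
import math
import random
from typing import List


def round_nearest(values: List[float]) -> List[int]:
    """Deterministic rounding: no random state involved."""
    return [math.floor(v + 0.5) for v in values]


def round_stochastic(values: List[float], rng: random.Random) -> List[int]:
    """Round up with probability equal to the fractional part."""
    out = []
    for v in values:
        lo = math.floor(v)
        out.append(lo + (1 if rng.random() < v - lo else 0))
    return out


values = [0.25, 1.5, 2.75, 3.1]

# Deterministic rounding: running the directions separately changes nothing.
det_a = round_nearest(values)
det_b = round_nearest(values)

# "Fused"-style: a single stream feeds rowwise, then columnwise.
shared = random.Random(0)
fused = (round_stochastic(values, shared), round_stochastic(values, shared))

# "Two-pass"-style: columnwise draws from its own, separately seeded stream.
row_rng, col_rng = random.Random(0), random.Random(1)
two_pass = (round_stochastic(values, row_rng), round_stochastic(values, col_rng))

print(det_a == det_b, fused[0] == two_pass[0], fused[1], two_pass[1])
```

The rowwise results match because both variants consume the same first draws; the columnwise results see different random numbers and so are not guaranteed to match, even though the rounding rule is identical.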
/te-ci pytorch L1
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Description
Hybrid (per-direction) quantization.
The main problem that it tries to solve is that precision requirements are non-uniform.
Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g. MXFP8 fwd and NVFP4 bwd (or vice versa) or any other valid combination. No need for a hardcoded recipe for every combination.
MXFP8 fwd + NVFP4 bwd:
C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong
TODO:
Integration
Ecosystem integration (all functional, unit-tested):
Megatron-LM integration status:
--fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
- [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination including same-format, cross-format Float8, single-direction)
- [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise): each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
--fp{4,8}-param-gather (fix private attribute access)
--fp{4,8}-param-gather
- [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
- [TODO] NVFP4 sub-storage FSDP2 hooks
_hybrid_split_quantize under Megatron MoE)
Review
Total diff +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py) ~1000
Adjacent modifications ~1000
Tests are the rest
Surface to review is ~2000 lines
Suggested reading order