[Pytorch][Common] Hybrid quantization #2817

Open: negvet wants to merge 37 commits into NVIDIA:main from negvet:hybrid_quantization

Conversation

@negvet negvet commented Mar 31, 2026

Collaborator

Description

Hybrid (per-direction) quantization.
The main problem it addresses is that precision requirements are non-uniform.

Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g., MXFP8 fwd and NVFP4 bwd (or vice versa), or any other valid combination, with no need for a hardcoded recipe per combination.

MXFP8 fwd + NVFP4 bwd:

# CustomRecipe calls quantization_factory(role) for each quantized tensor
# The factory chooses separate rowwise and columnwise formats whenever needed

def hybrid_factory(role):
    if role.tensor_type in ("grad_output", "grad_input"):
        return NVFP4Quantizer(...)

    return HybridQuantizer(
        rowwise_quantizer=MXFP8Quantizer(...),
        columnwise_quantizer=NVFP4Quantizer(...),
    )


recipe = CustomRecipe(qfactory=hybrid_factory)
with autocast(recipe=recipe):
    y = model(x)
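
The dispatch rule in the snippet above can be tried without Transformer Engine. The following is a minimal stand-alone sketch: `ToyRole`, `ToyQuantizer`, and `ToyHybrid` are hypothetical stand-ins, not TE classes, and only the branching logic is taken from the example (gradient tensors get a single NVFP4 quantizer, everything else gets MXFP8 rowwise plus NVFP4 columnwise).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyRole:
    """Hypothetical stand-in for the role object passed to the factory."""
    tensor_type: str


@dataclass(frozen=True)
class ToyQuantizer:
    """Hypothetical stand-in for a single-format quantizer."""
    fmt: str


@dataclass(frozen=True)
class ToyHybrid:
    """Hypothetical stand-in for a per-direction (hybrid) quantizer."""
    rowwise: ToyQuantizer
    columnwise: ToyQuantizer


def toy_hybrid_factory(role: ToyRole):
    # Gradient tensors use one format for both directions.
    if role.tensor_type in ("grad_output", "grad_input"):
        return ToyQuantizer("NVFP4")
    # Everything else gets a separate format per direction.
    return ToyHybrid(rowwise=ToyQuantizer("MXFP8"), columnwise=ToyQuantizer("NVFP4"))


for tensor_type in ("input", "weight", "grad_output"):
    print(tensor_type, "->", toy_hybrid_factory(ToyRole(tensor_type)))
```

Because the factory is an ordinary function of the role, adding a new combination is a matter of adding a branch rather than defining a new recipe class.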

C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong

TODO:

  • Double quantization
  • Non-hybrid convergence of base recipes (validation)
  • HybridFloat8BlockScaling is xfailed under FSDP2 because dim-0 shards can split 128-row block-scale tiles, producing all-gathered scale buffers whose shape does not match the global tensor.
  • Delayed scaling
  • Mid-training recipe change
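
To see why dim-0 sharding can produce a mismatched scale shape, here is a small arithmetic sketch. It assumes, purely for illustration, one row of scales per 128-row tile and even dim-0 sharding; the function names are hypothetical and the real scale layout may differ.

```python
import math

TILE_ROWS = 128  # block-scale tile height named in the TODO above


def scale_rows(num_rows: int) -> int:
    """Scale rows needed for `num_rows` data rows (one per 128-row tile)."""
    return math.ceil(num_rows / TILE_ROWS)


def gathered_vs_global(total_rows: int, world_size: int) -> tuple[int, int]:
    """Return (scale rows after gathering per-shard scales, scale rows of the global tensor)."""
    rows_per_shard = total_rows // world_size  # assumes even sharding
    gathered = world_size * scale_rows(rows_per_shard)
    return gathered, scale_rows(total_rows)


# Shards aligned to tiles: the shapes agree.
print(gathered_vs_global(1024, 4))  # (8, 8)
# 96-row shards split the 128-row tiles: the shapes disagree.
print(gathered_vs_global(384, 4))   # (4, 3)
```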

Integration

Ecosystem integration (all functional, unit-tested):

  • [Done] quantized_model_init
  • [Done] FSDP2 (TODO: optimize communication buffers)
  • [Done] CPU offloading
  • [Done] Activation recomputation
  • [Done] TP/SP (TODO: enable quantized AG)

Megatron-LM integration status:

  • [Done] 1 GPU baseline
  • [Done] DP + distributed optimizer
  • [TODO] quantized_model_init + --fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
    - [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination including same-format, cross-format Float8, single-direction)
    - [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
  • [TODO] Megatron-FSDP + --fp{4,8}-param-gather (fix private attribute access)
  • [TODO] Torch FSDP2 + --fp{4,8}-param-gather
    - [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
    - [TODO] NVFP4 sub-storage FSDP2 hooks
  • [Done] Activation recompute
  • [Done] CPU offload
  • [Done] TP/SP/PP
  • [Done] MoE + EP + grouped GEMM (qwen3 MoE; _hybrid_split_quantize under Megatron MoE)

Review

Total diff +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py) ~1000
Adjacent modifications ~1000
Tests are the rest

Surface to review is ~2000 lines

Suggested reading order

  1. Foundation — 7553e6a: Python containers + quantize/gemm dispatch/unwrap
  • tensor/hybrid_tensor.py — HybridQuantizer + HybridQuantizedTensor
  • tensor/storage/hybrid_tensor_storage.py
  • cpp_extensions/gemm.py — _unwrap_hybrid_A/B
  • common/transpose/quantize_transpose_square_blockwise.cu - Block FP8 columnwise-only null-checks
  • Module hooks in module/{base,grouped_linear,layernorm_linear,layernorm_mlp}.py
  • Tests: TestHybridQuantizer*, TestHybridGemmBitwiseIdentical* (proves zero-overhead vs vanilla recipes when both formats match), TestHybridDirectionUnwrap*, TestHybridGroupedLinear*
  2. quantized_model_init + FusedAdam - f80f5d0
  • hybrid_tensor.py::HybridQuantizer.update_quantized — delegates to each sub-quantizer; unblocks workspace-cache quantize_() and FusedAdam writeback
  • module/base.py workspace-cache invalidation
  • Tests: TestHybridQuantizedModelInit, TestHybridFusedAdam, TestHybridQuantizedParamsEndToEnd, TestHybridCheckpoint, TestQuantizedParamsEquivalence*
  3. FSDP2 support - 2185b30
  • New base FSDP2 buffer protocol on QuantizedTensorStorage: fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered. Generic, reusable beyond hybrid.
  • Per-format overrides on Float8TensorStorage (direction-aware) and MXFP8TensorStorage (strips/re-applies scale alignment padding around the gather)
  • hybrid_tensor.py::fsdp_pre/post_all_gather + torch_dispatch for the FSDP2 op set (view, split, as_strided, slice, copy_, new_zeros, clone, detach)
  • None-safety in float8_tensor.py and mxfp8_tensor.py for single-direction sub-storages (columnwise-only on Hopper/L40)
  • Tests: TestHybridTorchDispatchFSDP2Ops, TestHybridFsdpPreAllGatherProtocol, TestHybridFsdpRoundtrip (bitwise-exact against manual all_gather(dequantize(shard))), plus tests/pytorch/distributed/fsdp2_tests/
  4. CPU offloading - 103fffe
  • hybrid_tensor_storage.py::clear() (v1 path) + prepare_for_saving / restore_from_saved chain (v2 path)
  • hybrid_tensor.py::detach() re-wraps each sub-storage via make_like (required by cpu_offload_v2's detach → prepare_for_saving pattern; sharing sub-storage objects would null-out fields on the original)
  • TestHybridCpuOffloadPushPop, plus updates to test_cpu_offloading*.py
  5. Activation recomputation - 16fb371
  • Uses existing QuantizedTensorStorage::prepare_for_saving / restore_from_saved protocol, preserving ordering across both sub-storages
  • Tests: 20 bitwise tests in TestHybridActivationRecompute
  6. TP/SP - a50fd63
  • hybrid_tensor.py::HybridQuantizer.supports_only_rowwise_all_gather — overrides to handle the NVFP4 columnwise-dequantize gap in the BF16 fallback path
  • distributed.py::gather_along_first_dim — hybrid branch re-quantizes with both directions after AG (since hybrid has no _create_transpose synthesis path)
  • Tests: 9 distributed tests in run_hybrid_tp_sp.py / test_hybrid_tp_sp.py
  7. Megatron-LM integration - a164cd3
  • tensor/utils.py::_route_hybrid_to_buckets — per-direction dispatch for quantize_master_weights: iterates both sub-storages, routes each independently into the per-format bucket matching its own sub-quantizer type
  • Hybrid branches in replace_raw_data and post_all_gather_processing
  • Today: per-tensor Float8 sub-quantizers (delayed + current) work in any per-direction combination. Per-block sub-quantizers raise per-direction with in-code TODOs naming the unblocker.
  • Tests: TestHybridQuantizeMasterWeights, TestHybridPostAllGatherProcessing
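
The per-direction routing described for the Megatron-LM integration can be pictured with a toy sketch. Everything here (`ToySub`, `route_to_buckets`, the bucket keys) is hypothetical and only mirrors the behavior described above: visit both sub-storages, file each under the bucket matching its own sub-quantizer type, and raise for formats that are not supported yet.

```python
from collections import defaultdict
from dataclasses import dataclass

# Per-tensor Float8 formats, which the PR text lists as working today.
SUPPORTED = {"float8_delayed", "float8_current"}


@dataclass
class ToySub:
    """Hypothetical sub-storage tagged with its sub-quantizer type."""
    name: str
    quantizer_type: str


def route_to_buckets(hybrids):
    """File every non-empty direction into the bucket for its own format."""
    buckets = defaultdict(list)
    for rowwise, columnwise in hybrids:
        for direction, sub in (("rowwise", rowwise), ("columnwise", columnwise)):
            if sub is None:  # single-direction hybrid: nothing to route
                continue
            if sub.quantizer_type not in SUPPORTED:
                raise NotImplementedError(
                    f"{direction} sub-quantizer {sub.quantizer_type!r} is not supported"
                )
            buckets[sub.quantizer_type].append((direction, sub.name))
    return dict(buckets)


weights = [
    (ToySub("w0", "float8_current"), ToySub("w0", "float8_delayed")),
    (None, ToySub("w1", "float8_current")),
]
print(route_to_buckets(weights))
```

Routing each direction independently is what lets a cross-format hybrid reuse the existing per-format cast paths instead of needing a dedicated one.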

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps Bot commented Mar 31, 2026

Contributor

Greptile Summary

This PR introduces hybrid (per-direction) quantization, allowing different quantization formats for rowwise vs. columnwise GEMM directions — enabling, for example, MXFP8 forward + NVFP4 backward without hardcoding every format combination. It also adds a high-precision passthrough IdentityQuantizer/IdentityTensor pair for unquantized directions inside a hybrid recipe.

  • Core containers (hybrid_tensor.py, hybrid_tensor_storage.py): HybridQuantizer composes two sub-quantizers and pins each to its direction; HybridQuantizedTensor wraps two sub-storages with full support for pickling, CPU offload, activation recomputation, and the complete FSDP2 buffer protocol (fsdp_pre/post_all_gather, fsdp_buffer_fields, fsdp_extract_buffers, fsdp_assign_gathered).
  • Ecosystem integration: GEMM dispatch unwraps hybrid operands direction-appropriately (_unwrap_hybrid_A/B); GroupedLinear adds _hybrid_split_quantize for grouped MoE GEMM; quantize_master_weights gains _route_hybrid_to_buckets for distributed-optimizer support; Float8Tensor/Float8TensorStorage/MXFP8TensorStorage gain null-guards and direction-aware FSDP2 helpers for Hopper columnwise-only sub-storages.
  • Tests: ~7 400 lines covering bitwise-identity against vanilla recipes, FSDP2 round-trip, CPU offload, activation recompute, TP/SP, distributed optimizer, and checkpoint.

Confidence Score: 4/5

Safe to merge for the claimed-working paths (1-GPU, DP + dist-opt, FSDP2, CPU offload, activation recompute, TP/SP, MoE); the two missing activation_dtype casts in _hybrid_split_quantize are non-blocking style inconsistencies.

The hybrid quantization machinery is well-designed and the claimed integration paths are covered by extensive bitwise tests. Two call sites in GroupedLinear's backward pass pass grad_output and wgrad input to _hybrid_split_quantize without the explicit activation_dtype cast that the non-hybrid path applies; this is a minor numerical consistency gap, not a correctness failure, but it should be addressed before the hybrid grouped-linear backward is used in precision-sensitive mixed-dtype training.

transformer_engine/pytorch/module/grouped_linear.py — the hybrid backward quantization paths (grad_output split at ~line 1187 and wgrad input split at ~line 1320) are missing the activation_dtype cast present in every other split-quantize call site.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/tensor/hybrid_tensor.py | New HybridQuantizer + HybridQuantizedTensor; implements quantize_impl, make_empty, update_quantized, FSDP2 pre/post all-gather, torch_dispatch ops, CPU offload detach, and pickling. Core logic looks sound; fsdp_post_all_gather shape inference for purely columnwise-only hybrid is a known edge-case tracked in previous threads. |
| transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py | New HybridQuantizedTensorStorage; delegates prepare_for_saving, restore_from_saved, dequantize, view, and FSDP2 buffer protocol to sub-storages. repr and update_usage behaviour look correct after previous-thread fixes. |
| transformer_engine/pytorch/tensor/utils.py | Adds _route_hybrid_to_buckets, _update_transpose_only_float8_flat_fragment, _cast_master_weights_to_identity, and identity/hybrid branches in quantize_master_weights. The transpose-only scatter loop is correct; Hopper columnwise-only crash on the use_fsdp_shard_model_weights=False path is now guarded. The use_fsdp_shard_model_weights=True branch remains unguarded but only affects [TODO] Megatron-FSDP paths. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds _hybrid_split_quantize, _is_hybrid_quantizer_list, identity-quantizer helpers, and hybrid/identity dispatch branches in forward and backward. The activation_dtype cast is missing in _hybrid_split_quantize for grad_output and wgrad input paths (non-blocking, P2). |
| transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py | Adds direction-aware fsdp_buffer_fields (returns _transpose when _data is None), fsdp_assign_gathered (clears _transpose_invalid), and null-guards in view/_create_transpose for columnwise-only sub-storages. Addresses previously-raised FSDP2 buffer correctness issues. |
| transformer_engine/pytorch/tensor/float8_tensor.py | Null-guards _data throughout view, split, clone, remove_caches for Hopper columnwise-only sub-storages; adds _canonical_view_shape and _columnwise_shape_for helpers. The split shape fix correctly uses the transposed buffer dimensions to derive the correct logical shard shape. |
| transformer_engine/pytorch/cpp_extensions/gemm.py | Adds _unwrap_hybrid_A/B and _materialize_high_precision helpers that extract the direction-appropriate sub-storage before the C++ GEMM call; also adds _reject_unsupported_output_quantizer guard for HybridQuantizer/IdentityQuantizer output slots. |
| transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py | Adds fsdp_buffer_fields, fsdp_extract_buffers (strips block-scale alignment padding), and fsdp_assign_gathered (re-pads scales) for MXFP8 hybrid FSDP2 support. Scale padding logic is correct for both rowwise [128,4] and columnwise [4,128] layouts. |
| transformer_engine/pytorch/distributed.py | Adds hybrid branch in gather_along_first_dim that temporarily forces (rowwise=True, columnwise=True) on the HybridQuantizer before quantizing the all-gathered BF16 output, then restores original flags in a finally block. Correct. |
| transformer_engine/pytorch/tensor/identity_tensor.py | New IdentityQuantizer / IdentityTensor high-precision passthrough; full quantize/dequantize/FSDP2/CPU offload/pickle protocol implemented. |
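
The strip/re-pad step mentioned for MXFP8 in the table above can be motivated with row counts alone. This is a hedged sketch, assuming scale buffers are padded along rows to a multiple of 128 (the rowwise alignment quoted above); the helper names are hypothetical and no real TE buffers are involved.

```python
ROW_ALIGN = 128  # rowwise scale alignment quoted in the table above


def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def gather_padded(shard_scale_rows: list[int]) -> int:
    """Naive gather: concatenates already-padded shards, padding included."""
    return sum(round_up(r, ROW_ALIGN) for r in shard_scale_rows)


def gather_stripped(shard_scale_rows: list[int]) -> int:
    """Strip padding, concatenate the valid rows, then pad once at the end."""
    return round_up(sum(shard_scale_rows), ROW_ALIGN)


shards = [32, 32, 32, 32]       # valid scale rows held by each of 4 ranks
print(gather_padded(shards))    # 512: padding rows interleaved with data
print(gather_stripped(shards))  # 128: matches the unsharded layout
```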

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class Quantizer {
        +rowwise_usage: bool
        +columnwise_usage: bool
        +quantize(tensor)
        +make_empty(shape)
        +update_quantized(src, dst)
    }
    class HybridQuantizer {
        +rowwise_quantizer: Quantizer
        +columnwise_quantizer: Quantizer
        +quantize_impl(tensor) HybridQuantizedTensor
        +supports_only_rowwise_all_gather() bool
        +with_amax_reduction: bool
    }
    class QuantizedTensorStorage {
        +fsdp_buffer_fields()
        +fsdp_extract_buffers()
        +fsdp_assign_gathered()
        +prepare_for_saving()
        +restore_from_saved()
    }
    class HybridQuantizedTensorStorage {
        +_rowwise_storage: QuantizedTensorStorage
        +_columnwise_storage: QuantizedTensorStorage
        +_quantizer: HybridQuantizer
        +update_usage()
        +dequantize()
        +view()
    }
    class HybridQuantizedTensor {
        +fsdp_pre_all_gather()
        +fsdp_post_all_gather()
        +detach() HybridQuantizedTensor
        +__torch_dispatch__()
        +__reduce_ex__()
    }
    class IdentityQuantizer {
        +dtype: Optional~torch.dtype~
        +quantize_impl(tensor) IdentityTensor
    }
    class Float8TensorStorage {
        +fsdp_buffer_fields() direction-aware
        +fsdp_assign_gathered() clears _transpose_invalid
        +view() null-guards _data
    }
    class MXFP8TensorStorage {
        +fsdp_buffer_fields()
        +fsdp_extract_buffers() strips scale padding
        +fsdp_assign_gathered() re-pads scales
    }
    Quantizer <|-- HybridQuantizer
    Quantizer <|-- IdentityQuantizer
    QuantizedTensorStorage <|-- HybridQuantizedTensorStorage
    HybridQuantizedTensorStorage <|-- HybridQuantizedTensor
    HybridQuantizer "1" --> "2" Quantizer : owns rowwise + columnwise
    HybridQuantizedTensorStorage "1" --> "2" QuantizedTensorStorage : rowwise + columnwise sub
    QuantizedTensorStorage <|-- Float8TensorStorage
    QuantizedTensorStorage <|-- MXFP8TensorStorage

Reviews (14): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated

@timmoon10 timmoon10 left a comment

Member

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)

Member

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
- rowwise_result = self.rowwise_quantizer.quantize(tensor)
- columnwise_result = self.columnwise_quantizer.quantize(tensor)
+ rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+ columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

@negvet negvet May 21, 2026

Collaborator Author

Fixed in 4858491
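
The usage gating suggested in this thread can be illustrated with a stand-alone toy. `CountingQuantizer` and `ToyHybridQuantizer` are hypothetical classes written for this sketch; only the two conditional expressions come from the suggested change, and the actual fix in the PR may be structured differently.

```python
class CountingQuantizer:
    """Hypothetical sub-quantizer that records how often it runs."""

    def __init__(self, name):
        self.name = name
        self.calls = 0

    def quantize(self, tensor):
        self.calls += 1
        return (self.name, list(tensor))


class ToyHybridQuantizer:
    """Hypothetical hybrid wrapper that honors per-direction usage flags."""

    def __init__(self, rowwise_quantizer, columnwise_quantizer):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer
        self.rowwise_usage = True
        self.columnwise_usage = True

    def quantize(self, tensor):
        # Skip any direction that nobody will consume.
        row = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
        col = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
        return row, col


q = ToyHybridQuantizer(CountingQuantizer("row"), CountingQuantizer("col"))
q.columnwise_usage = False
row, col = q.quantize([1.0, 2.0])
print(row, col, q.columnwise_quantizer.calls)
```

With the flag cleared, the columnwise sub-quantizer is never invoked, so no work or memory is spent on a direction that will not be read.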

requires_grad: bool = False,
pin_memory: bool = False,
) -> HybridQuantizedTensor:
self.rowwise_quantizer.internal = True

Member

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Collaborator Author

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
    if role == "linear_weight":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_mxfp8_quantizer(),
        )
    if role == "linear_input":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    if role in ("linear_grad_output", "linear_grad_input"):
        return HybridQuantizer(
            rowwise_quantizer=_make_mxfp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    return None

Member

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py
negvet and others added 2 commits April 29, 2026 16:02
Comment thread transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
negvet added 3 commits May 13, 2026 12:34
@negvet negvet requested a review from ksivaman as a code owner May 21, 2026 13:53
Comment on lines 665 to 677
  outs = [
      Float8Tensor.make_like(
          tensor,
          data=split_tensor,
          data_transpose=split_transpose_tensor,
-         shape=split_tensor.shape,
+         shape=(
+             split_tensor.shape
+             if split_tensor is not None
+             else split_transpose_tensor.shape
+         ),
      )
      for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
  ]

Contributor

P1 When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.

Suggested change
- outs = [
-     Float8Tensor.make_like(
-         tensor,
-         data=split_tensor,
-         data_transpose=split_transpose_tensor,
-         shape=split_tensor.shape,
-         shape=(
-             split_tensor.shape
-             if split_tensor is not None
-             else split_transpose_tensor.shape
-         ),
-     )
-     for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
- ]
+ outs = [
+     Float8Tensor.make_like(
+         tensor,
+         data=split_tensor,
+         data_transpose=split_transpose_tensor,
+         shape=(
+             split_tensor.shape
+             if split_tensor is not None
+             # _transpose has shape [K, M/n] but the shard's nominal shape
+             # is [M/n, K]. Recover the correct shard shape by reversing
+             # the last two dims of the transposed piece.
+             else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0])
+         ),
+     )
+     for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
+ ]

Collaborator Author

should be fixed in 8cc3332
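
The shape recovery in the suggested expression can be checked on plain tuples. This sketch assumes the transposed buffer stores the last logical dimension first (for 2D that is the ordinary transpose, matching the [K, M/n] versus [M/n, K] shapes in the comment above); both helper names are hypothetical.

```python
def transposed_shape(shape):
    """Shape of the transposed buffer, assuming the last dim moves to the front."""
    return (shape[-1], *shape[:-1])


def logical_shape_from_transpose(t_shape):
    """Invert `transposed_shape`: move the leading dim back to the end."""
    return (*t_shape[1:], t_shape[0])


shard = (256, 1024)                # logical [M/n, K]
t_shard = transposed_shape(shard)  # (1024, 256), i.e. [K, M/n]
print(t_shard, logical_shape_from_transpose(t_shard))
```

Note that for inputs with more than two dimensions the expression is a rotation of the dimensions rather than a swap of the last two, which is why the round trip is written in terms of moving one dimension.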

Comment on lines +27 to +30
# DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories
# (lambdas, inner functions referencing captured state) are not picklable,
# so the qfactory must live at module scope. See
# ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``.

Member

This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?

Collaborator Author

Fixed
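
The picklability constraint in the quoted code comment is a general Python property and can be reproduced with the standard library alone. In this sketch `is_picklable` and `make_closure_factory` are hypothetical helpers; `math.floor` stands in for any function that lives at the top level of an importable module.

```python
import math
import pickle


def is_picklable(obj) -> bool:
    """Return True if `pickle.dumps` accepts `obj`."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


def make_closure_factory(fmt):
    # Inner function capturing `fmt`. Functions are pickled by qualified
    # name, and a local object's name cannot be resolved on import.
    def factory(role):
        return (fmt, role)

    return factory


print(is_picklable(lambda role: role))              # False
print(is_picklable(make_closure_factory("mxfp8")))  # False
print(is_picklable(math.floor))                     # True
```

Functions are pickled by reference (module plus qualified name), so only objects that can be looked up again by that name survive the round trip.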

Comment on lines +1177 to +1184
for param in model.parameters():
    state = optimizer.state[param]
    assert state["exp_avg"].dtype == torch.float32
    assert state["exp_avg_sq"].dtype == torch.float32
    if "master_param" in state:
        assert state["master_param"].dtype == torch.float32

assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}"

Member

That's not a very strict test, is there a way for us to do some numerical correctness comparisons?

Collaborator Author

Enabled a check for monotonic loss decrease (still mostly a sanity check), and also enabled a bitwise hybrid vs vanilla recipe comparison; see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.
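
The difference between the original end-point check, a monotonic check, and a bitwise comparison can be shown on plain floats. This is an illustrative sketch only: the helper names are hypothetical, "monotonic" is interpreted here as strictly decreasing, and the real tests compare tensors rather than Python lists.

```python
import struct


def is_strictly_decreasing(losses) -> bool:
    """True if every loss is lower than the one before it."""
    return all(b < a for a, b in zip(losses, losses[1:]))


def bitwise_equal(xs, ys) -> bool:
    """Compare float sequences by bit pattern (float64), not by tolerance."""
    if len(xs) != len(ys):
        return False
    return all(struct.pack("<d", x) == struct.pack("<d", y) for x, y in zip(xs, ys))


print(is_strictly_decreasing([2.0, 1.5, 1.4]))  # True
print(is_strictly_decreasing([2.0, 2.1, 1.0]))  # False, yet losses[-1] < losses[0] holds
print(bitwise_equal([0.1 + 0.2], [0.3]))        # False: differs in the last bit
print(bitwise_equal([0.0], [-0.0]))             # False, although 0.0 == -0.0
```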

Comment on lines +126 to +131
"""Default NVFP4Quantizer: no RHT, no stochastic rounding, no 2D
scaling — matches upstream ``run_numerics.py::nvfp4_vanilla()`` which
strips the recipe to bare 1D block scaling for distributed TP
fairness. Same-format hybrid NVFP4 has no E5M2 variant (NVFP4 is a
single format family — E2M1 only), so grad roles reuse the same
NVFP4 quantizer."""

Member

Why don't we want to check the full recipe here?

Collaborator Author

Switched to the full recipe except 1D for weights; will be enabled after #3027 merges.

Comment on lines +136 to +143
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
    return HybridQuantizer(
        rowwise_quantizer=_make_nvfp4_quantizer(),
        columnwise_quantizer=_make_nvfp4_quantizer(),
    )
if is_linear and role.tensor_type in ("grad_output", "grad_input"):
    return _make_nvfp4_quantizer()

Member

As written those lines are not needed at all. They would be needed if you did the full recipe.

Collaborator Author

Switched to the full recipe

Comment on lines +166 to +168
# quantization (rowwise and columnwise quantizers run independently, so
# their outputs may differ by ~1 ULP from a single fused-quantize path
# in edge cases).

Member

That does not sound like a good thing if it actually happens in practice. Quantization alone should not be affected by whether you do both directions at the same time or one at a time, since the input and the algorithm are the same in both cases. Fusion with the activations could maybe give slightly different results, but I would still like to get an explanation of why that would be.

Collaborator Author

You are right. If the algorithm is the same, we do get identical results; the new bitwise linear_vs_vanilla test confirms this. The only place where two-pass and fused quantization differ is NVFP4 with RHT + SR: this activates a separate columnwise RNG (need_separate_columnwise_rng), so the RNG stream is consumed differently. See the comment in _backward_not_bitwise_comparable(). Removed the comment.
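
Why stream consumption order matters for stochastic rounding can be shown with Python's `random` module. This is only an analogy under stated assumptions: it does not model TE's RNG or `need_separate_columnwise_rng`, and `stochastic_round` and `draws` are hypothetical helpers.

```python
import random


def stochastic_round(values, rng):
    """Round each value up with probability equal to its fractional part."""
    out = []
    for v in values:
        floor = int(v // 1)
        out.append(floor + (1 if rng.random() < (v - floor) else 0))
    return out


def draws(rng, n):
    """The raw uniform draws that a pass over n elements would consume."""
    return [rng.random() for _ in range(n)]


# Fused-style: one stream, so the columnwise pass consumes draws n..2n-1.
shared = random.Random(0)
row_draws_fused = draws(shared, 4)
col_draws_fused = draws(shared, 4)

# Two-pass with a separate columnwise generator: it starts its own stream.
row_draws_split = draws(random.Random(0), 4)
col_draws_split = draws(random.Random(1), 4)

print(row_draws_fused == row_draws_split)  # True: same seed, same position
print(col_draws_fused == col_draws_split)  # False: different stream
print(stochastic_round([0.5, 1.0, 2.25], random.Random(0)))
```

Deterministic rounding would be unaffected by any of this; only the stochastic part depends on which random numbers each element happens to receive.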

negvet and others added 6 commits June 1, 2026 08:47
@negvet negvet requested a review from cyanguwa as a code owner June 10, 2026 16:49
negvet and others added 2 commits June 10, 2026 16:53

negvet commented Jun 10, 2026

Collaborator Author

/te-ci pytorch L1


negvet commented Jun 11, 2026

Collaborator Author

/te-ci pytorch L1

negvet added 2 commits June 12, 2026 13:15

negvet commented Jun 15, 2026

Collaborator Author

/te-ci pytorch L1

negvet commented Jun 16, 2026

Collaborator Author

/te-ci pytorch L1

4 participants