[PyTorch][torch.compile] Remove process group from quantizers#3104

Open
pggPL wants to merge 5 commits into NVIDIA:main from pggPL:remove_process_group_from_quantizers

Conversation

pggPL commented Jun 8, 2026

Collaborator

Description

Removing process groups from quantizer state makes adding torch.compile support much easier.
Move amax reduction process group handling out of quantizer state and pass it per quantization call instead. This avoids storing process groups inside quantizers while keeping deprecated stored-group fallback behavior for compatibility.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality not to work as expected)
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Pass amax_reduction_group through quantize/module call paths instead of storing it on quantizers.
  • Preserve deprecated constructor/state fallback for existing callers, excluding process groups from serialization.
  • Update FP8/NVFP4/MXFP8/blockwise tensor quantization paths and C++ bindings to resolve reduction groups per call.
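
The pattern in the list above can be sketched with plain-Python stand-ins. This is an illustration only, not Transformer Engine code: `ToyQuantizer`, `FakeGroup`, the `_deprecated_group` field, and the `quantize` signature are all hypothetical names chosen for the example, and the "all-reduce" is left as a comment.

```python
from typing import Any, List, Optional, Tuple


class FakeGroup:
    """Hypothetical stand-in for a process group handle."""

    def __init__(self, name: str) -> None:
        self.name = name


class ToyQuantizer:
    """Hypothetical quantizer: the group arrives per call, with a deprecated stored fallback."""

    def __init__(self, amax_reduction_group: Optional[Any] = None) -> None:
        # Deprecated path (assumption): group supplied at construction time.
        self._deprecated_group = amax_reduction_group

    def quantize(
        self, values: List[float], *, amax_reduction_group: Optional[Any] = None
    ) -> Tuple[float, Optional[str]]:
        # Resolve the group per call, falling back to the deprecated stored group.
        group = amax_reduction_group
        if group is None:
            group = self._deprecated_group
        amax = max(abs(v) for v in values)
        # A real implementation would all-reduce `amax` over `group` here.
        return amax, (group.name if group is not None else None)

    def __getstate__(self) -> dict:
        # Exclude the process group from serialized state.
        state = self.__dict__.copy()
        state["_deprecated_group"] = None
        return state


tp_group = FakeGroup("tp")

# New style: the quantizer holds no group, the caller passes it in.
per_call = ToyQuantizer().quantize([0.5, -2.0], amax_reduction_group=tp_group)

# Deprecated style: the stored group is still honored, but never serialized.
legacy = ToyQuantizer(amax_reduction_group=tp_group)
state = legacy.__getstate__()  # this is what pickle would capture
```

The point of the sketch is that the quantizer's own state stays free of process-group objects in the new style, while existing callers that pass a group to the constructor keep working.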

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL commented Jun 8, 2026

Collaborator Author

/te-ci pytorch L1

greptile-apps Bot commented Jun 8, 2026

Contributor

Greptile Summary

This PR refactors amax reduction group handling by moving it out of persistent quantizer state and into per-call parameters, which simplifies torch.compile compatibility. The quantizer copy() methods are updated to not propagate the group, and a new set_quantizer_amax_reduction_group helper applies it at call sites in forward/backward implementations across all three module types.

  • Per-call group injection: set_quantizer_amax_reduction_group is now called at the top of each forward/backward function for the relevant quantizer (input for column-parallel, grad-output for row-parallel), replacing static setup in _customize_quantizers_nvfp4 and the float8-current-scaling parallel block.
  • FSDP2 weight path: Float8Tensor and NVFP4Tensor gain a class-level amax_reduction_group attribute set in fsdp_pre_all_gather; update_quantized and _set_data read this attribute and create a throwaway quantizer copy when it is non-None, keeping the FSDP2 weight-update path correct without storing the group permanently.
  • Backward compatibility: The deprecated constructor-supplied group is preserved when no tensor-level group is present, and __getstate__ explicitly nulls the group to exclude it from serialization.
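
As a rough illustration of the per-call injection described above, the helper might behave like the sketch below. The body is inferred from this summary only (overwrite both fields, skip a `None` quantizer, skip quantizers without the attribute); the function name `sketch_set_amax_reduction_group` and the `SimpleNamespace` stand-ins are hypothetical, and the real `set_quantizer_amax_reduction_group` in the PR may differ.

```python
from types import SimpleNamespace
from typing import Any, Optional


def sketch_set_amax_reduction_group(quantizer: Optional[Any], group: Optional[Any]) -> None:
    """Sketch only: overwrite the amax-reduction fields on a quantizer that has them."""
    if quantizer is None:
        return
    if not hasattr(quantizer, "amax_reduction_group"):
        # Recipes without amax reduction expose no such field and are left untouched.
        return
    quantizer.with_amax_reduction = group is not None
    quantizer.amax_reduction_group = group


# Hypothetical stand-ins for two quantizers and a tensor-parallel group.
fp8_quantizer = SimpleNamespace(with_amax_reduction=False, amax_reduction_group=None)
block_quantizer = SimpleNamespace()  # no amax-reduction fields at all
tp_group = object()

# Called at the top of a forward or backward function, right before quantizing.
sketch_set_amax_reduction_group(fp8_quantizer, tp_group)
sketch_set_amax_reduction_group(block_quantizer, tp_group)
sketch_set_amax_reduction_group(None, tp_group)
```

Because the fields are overwritten on every call, passing `None` as the group also clears any group left over from an earlier pass.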

Confidence Score: 5/5

Safe to merge. The refactoring correctly preserves all amax-reduction semantics for both the activation quantization path and the FSDP2 weight path.

The core invariant — that every quantization call that previously performed an amax all-reduce still does so after the refactor — holds across all three module types and the FSDP2 weight-update path. The only substantive asymmetry is that _linear_backward is more explicit about resetting the input_quantizer group than the layernorm backward functions, but the layernorm cases are safe because they execute the backward immediately after the forward in the same training step.

The dual code paths in float8_tensor.py (_set_data and update_quantized) both independently implement the FSDP2 group logic; either could diverge in a future edit without breaking the other.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/_common.py | Adds set_quantizer_amax_reduction_group helper that unconditionally overwrites the quantizer's with_amax_reduction and amax_reduction_group fields; guards correctly against None quantizer and missing attribute. |
| transformer_engine/pytorch/module/linear.py | Removes static _customize_quantizers_nvfp4; _linear_forward_impl and _linear_backward now call set_quantizer_amax_reduction_group directly. Backward also explicitly sets input_quantizer for column-parallel, which mirrors the forward behaviour and is consistent. |
| transformer_engine/pytorch/module/layernorm_linear.py | Drops _customize_quantizers_nvfp4; forward and backward autograd function inject the group via set_quantizer_amax_reduction_group at the point of use. |
| transformer_engine/pytorch/module/layernorm_mlp.py | Drops _customize_quantizers_nvfp4; backward correctly targets fc2_grad_output_quantizer (= GRAD_OUTPUT2) matching the old static path; fc1_grad_output_quantizer is intentionally left without amax reduction. |
| transformer_engine/pytorch/ops/basic/basic_linear.py | Removes the NVFP4 and float8 amax-reduction setup from _customize_quantizers; forward and backward now call set_quantizer_amax_reduction_group at the three relevant use sites (forward input, backward grad-output, backward wgrad input). |
| transformer_engine/pytorch/tensor/float8_tensor.py | Class-level amax_reduction_group attribute added for FSDP2; fsdp_pre_all_gather sets it on the shard, _set_data and update_quantized create throwaway quantizer copies with the group. The copy() intentionally strips the group. Dual code paths for _set_data vs quantize_() are both correct but warrant attention. |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | Mirrors Float8 changes: class-level amax_reduction_group on NVFP4Tensor, fsdp_pre_all_gather sets it, update_quantized consumes it via throwaway copy. copy() no longer propagates the group. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Module
    participant FwdBwdFn as Forward/Backward Fn
    participant Helper as set_quantizer_amax_reduction_group
    participant Quantizer
    participant texQuantize as tex.quantize (C++)

    Note over Module,texQuantize: Per-call group injection (activation path)
    Module->>FwdBwdFn: forward(input) / backward(grad)
    FwdBwdFn->>Helper: set_quantizer_amax_reduction_group(q, tp_group or None)
    Helper->>Quantizer: "q.with_amax_reduction = True/False"
    Helper->>Quantizer: "q.amax_reduction_group = group/None"
    FwdBwdFn->>Quantizer: quantizer(tensor)
    Quantizer->>texQuantize: tex.quantize(src, self, dst)

    Note over Module,texQuantize: FSDP2 weight path
    Module->>Module: fsdp_pre_all_gather(mesh)
    Module->>Module: "tensor.amax_reduction_group = mesh.get_group()"
    Module->>Module: "tensor.data = fp32_weight"
    Module->>Quantizer: quantizer.copy() to throwaway
    Module->>Quantizer: "throwaway.with_amax_reduction = True"
    Module->>texQuantize: tex.quantize(src, throwaway, dst)

Reviews (6): Last reviewed commit: "Drop redundant placeholder comments left..."

Comment on lines +326 to +329
     """Quantize tensor"""
-    return self.quantize(tensor)
+    if amax_reduction_group is None:
+        return self.quantize(tensor)
+    return self.quantize(tensor, amax_reduction_group=amax_reduction_group)

Contributor

P2 The None guard here is redundant: self.quantize(tensor) and self.quantize(tensor, amax_reduction_group=None) are identical because quantize defaults the argument to None. The branch just adds noise.

Suggested change
    """Quantize tensor"""
    return self.quantize(tensor, amax_reduction_group=amax_reduction_group)
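
A small self-contained check of the claim in this comment, assuming (as the comment states) that `quantize` defaults `amax_reduction_group` to `None`. `RecordingQuantizer` and its two wrapper methods are hypothetical names used only to compare the guarded and unguarded call styles.

```python
from typing import Any, List, Optional


class RecordingQuantizer:
    """Hypothetical quantizer that records the group each quantize() call receives."""

    def __init__(self) -> None:
        self.seen: List[Optional[Any]] = []

    def quantize(self, tensor: List[float], *, amax_reduction_group: Optional[Any] = None):
        self.seen.append(amax_reduction_group)
        return tensor

    def call_guarded(self, tensor: List[float], amax_reduction_group: Optional[Any] = None):
        # Style under review: branch on None before forwarding.
        if amax_reduction_group is None:
            return self.quantize(tensor)
        return self.quantize(tensor, amax_reduction_group=amax_reduction_group)

    def call_direct(self, tensor: List[float], amax_reduction_group: Optional[Any] = None):
        # Suggested style: always forward the argument.
        return self.quantize(tensor, amax_reduction_group=amax_reduction_group)


group = object()  # stand-in for a process group
guarded, direct = RecordingQuantizer(), RecordingQuantizer()
for candidate in (None, group):
    guarded.call_guarded([1.0], candidate)
    direct.call_direct([1.0], candidate)
```

Both instances record the same sequence of groups, so the branch changes nothing as long as the default stays `None`.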


Comment on lines +200 to +203
    @property
    def rht_matrix(self) -> torch.Tensor:
        """RHT matrix (fetched from the process-global cache, not stored per quantizer)."""
        return get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device())

Contributor

P2 Deserialization break for old pickled NVFP4Quantizer instances

rht_matrix is now a property that reads self._with_random_sign_mask, but _with_random_sign_mask is a new field that did not exist in pickled state produced before this change. When Python's default __setstate__ (i.e., self.__dict__.update(state)) loads an old pickle, _with_random_sign_mask is absent, so any access to the rht_matrix property raises AttributeError. A __setstate__ that infers _with_random_sign_mask from the old stored rht_matrix (or supplies a safe default) would preserve backward compatibility for serialized quantizers.
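
One way such a `__setstate__` could look is sketched below, using a toy class rather than the real `NVFP4Quantizer`. The class name `ToyNVFP4Quantizer`, the string standing in for the matrix, and the default of `True` for the missing field are all assumptions made for illustration; the right default would depend on what old pickles actually stored.

```python
from typing import Any, Dict


class ToyNVFP4Quantizer:
    """Hypothetical stand-in; only the state-migration logic is the point."""

    def __init__(self, with_random_sign_mask: bool = True) -> None:
        self._with_random_sign_mask = with_random_sign_mask

    @property
    def rht_matrix(self) -> str:
        # Stand-in for fetching the matrix from a process-global cache.
        return f"rht(sign_mask={self._with_random_sign_mask})"

    def __setstate__(self, state: Dict[str, Any]) -> None:
        state = dict(state)
        # Old pickles stored the matrix itself; it is now derived, so drop it.
        state.pop("rht_matrix", None)
        # Old pickles lack the new field; supply a default (assumed value).
        state.setdefault("_with_random_sign_mask", True)
        self.__dict__.update(state)


# Simulate restoring state written before the field existed.
old_state = {"rht_matrix": "legacy-matrix"}
restored = ToyNVFP4Quantizer.__new__(ToyNVFP4Quantizer)
restored.__setstate__(old_state)
matrix = restored.rht_matrix  # no AttributeError
```

Without the `setdefault` line, the property access at the end would raise `AttributeError`, which is the failure mode described above.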

pggPL commented Jun 8, 2026

Collaborator Author

Blocked by FSDP bug, refactor in progress.

I plan to store .amax_reduction_group in QuantizedTensor.

@timmoon10 timmoon10 left a comment

Member

This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).

Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:

class Float8CurrentScalingQuantizer(Quantizer):

    # Process-wide cache mapping communicator IDs to communicator objects
    _communicator_cache = {}

    @property
    def amax_reduction_group(self):
        if self._amax_reduction_group_id is None:
            return None
        return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]

    @amax_reduction_group.setter
    def amax_reduction_group(self, comm):
        if comm is None:
            self._amax_reduction_group_id = None
            return
        self._amax_reduction_group_id = id(comm)
        Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = comm

I'm not sure how this would interact with checkpointing though.
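
To make the checkpointing question concrete, here is a toy version of the ID-cache idea with a simulated reload. Everything here is hypothetical: `ToyQuantizer` has no real base class, `object()` stands in for a communicator, and clearing the cache is only a rough stand-in for starting a fresh process, where `id()` values from the earlier run carry no meaning.

```python
from typing import Any, Dict, Optional


class ToyQuantizer:
    """Hypothetical quantizer that stores only an integer key for its group."""

    _communicator_cache: Dict[int, Any] = {}

    def __init__(self) -> None:
        self._amax_reduction_group_id: Optional[int] = None

    @property
    def amax_reduction_group(self) -> Optional[Any]:
        if self._amax_reduction_group_id is None:
            return None
        return ToyQuantizer._communicator_cache[self._amax_reduction_group_id]

    @amax_reduction_group.setter
    def amax_reduction_group(self, comm: Optional[Any]) -> None:
        if comm is None:
            self._amax_reduction_group_id = None
            return
        self._amax_reduction_group_id = id(comm)
        ToyQuantizer._communicator_cache[self._amax_reduction_group_id] = comm


comm = object()  # stand-in for a process group
quantizer = ToyQuantizer()
quantizer.amax_reduction_group = comm
resolved = quantizer.amax_reduction_group  # looked up through the cache

# The instance state is a plain int, which is what a checkpoint would capture.
saved_state = dict(quantizer.__dict__)

# Simulate loading that state where the cache starts out empty.
ToyQuantizer._communicator_cache.clear()
reloaded = ToyQuantizer()
reloaded.__dict__.update(saved_state)
try:
    reloaded.amax_reduction_group
    lookup_failed = False
except KeyError:
    lookup_failed = True
```

In this toy setup the reloaded quantizer holds a key with nothing behind it, which suggests a checkpointed ID would need to be dropped or re-registered on load.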

dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument

Member

I strongly oppose this API change. Amax reduction is very recipe-specific: it has different meanings for different recipes (FP8 delayed scaling might reduce over the TP+DP group, FP8 current scaling might only reduce over the TP group) and no meaning for others (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.

pggPL commented Jun 10, 2026

Collaborator Author

/te-ci pytorch L1

… fwd/bwd impls

The amax reduction process group is no longer stored persistently on the module
quantizer at setup. Two paths, no C++ changes (the cast kernel still reads the
group off the quantizer at quantize time):

- TP sequence parallel: set the group on the input/grad-output quantizer at point
  of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp, ops
  basic_linear) via set_quantizer_amax_reduction_group, replacing the setup-time
  _customize_quantizers group-setting (and dropping _customize_quantizers_nvfp4).
  Avoids a shared quantizer's group being clobbered across interleaved passes.
- FSDP2: store the group on Float8Tensor/NVFP4Tensor (set in fsdp_pre_all_gather)
  and bridge it onto the quantizer before the in-place re-quantization in
  update_quantized / _set_data, so all weight shards stay globally scaled.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
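
The FSDP2 path in the commit message above can be pictured with the following toy model. It is a sketch under stated assumptions, not the real implementation: `ToyQuantizer`, `ToyQuantizedTensor`, `FakeGroup`, and the simplified `fsdp_pre_all_gather(group)` signature are hypothetical, and the all-reduce is simulated by taking a maximum over hard-coded peer values.

```python
import copy
from typing import Any, List, Optional


class FakeGroup:
    """Hypothetical group whose all-reduce(MAX) is simulated with fixed peer values."""

    def __init__(self, peer_amaxes: List[float]) -> None:
        self.peer_amaxes = peer_amaxes

    def all_reduce_max(self, local: float) -> float:
        return max([local] + self.peer_amaxes)


class ToyQuantizer:
    def __init__(self) -> None:
        self.with_amax_reduction = False
        self.amax_reduction_group: Optional[Any] = None

    def copy(self) -> "ToyQuantizer":
        return copy.copy(self)

    def compute_amax(self, shard: List[float]) -> float:
        local = max(abs(v) for v in shard)
        if self.with_amax_reduction:
            return self.amax_reduction_group.all_reduce_max(local)
        return local


class ToyQuantizedTensor:
    amax_reduction_group: Optional[Any] = None  # class-level default

    def __init__(self, quantizer: ToyQuantizer) -> None:
        self._quantizer = quantizer
        self.amax: Optional[float] = None

    def fsdp_pre_all_gather(self, group: Any) -> None:
        # The group rides on the tensor, not on its quantizer.
        self.amax_reduction_group = group

    def update_quantized(self, shard: List[float]) -> None:
        quantizer = self._quantizer
        if self.amax_reduction_group is not None:
            # Apply the group to a throwaway copy; the tensor's own quantizer stays clean.
            quantizer = self._quantizer.copy()
            quantizer.with_amax_reduction = True
            quantizer.amax_reduction_group = self.amax_reduction_group
        self.amax = quantizer.compute_amax(shard)


weight = ToyQuantizedTensor(ToyQuantizer())
weight.update_quantized([0.5, -1.0])
local_amax = weight.amax  # no group yet, so only the local shard counts

weight.fsdp_pre_all_gather(FakeGroup(peer_amaxes=[4.0]))
weight.update_quantized([0.5, -1.0])
global_amax = weight.amax  # now reflects the largest value across shards
```

After the second update the tensor's own quantizer still has no group attached, which is the property the commit message is after.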
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from e9097d6 to 948cd6d Compare June 16, 2026 12:23
pre-commit-ci Bot and others added 4 commits June 16, 2026 12:24
The amax reduction group must only ride on the QuantizedTensor, never on its
quantizer. copy() no longer propagates with_amax_reduction/amax_reduction_group
(so every freshly-quantized output tensor gets a clean quantizer), and the
in-place FSDP re-quant paths (_set_data, update_quantized for Float8 and NVFP4)
apply the group on a throwaway quantizer copy instead of mutating the tensor's
own quantizer. Also trim verbose comments.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…oup wiring

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Strip the group off the output tensor's quantizer at the quantize() chokepoint
after a fresh quantize, and restore copy() to upstream behavior. This keeps the
group off every tensor's quantizer (FSDP weight init, _set_data, TP activations)
while leaving copy() — and thus fsdp_post_all_gather / make_like / explicit
.copy() — unchanged, narrowing the backward-compat impact.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
2 participants