[PyTorch][torch.compile] Remove process group from quantizers#3104
[PyTorch][torch.compile] Remove process group from quantizers#3104pggPL wants to merge 5 commits into
Conversation
|
/te-ci pytorch L1 |
Greptile SummaryThis PR refactors amax reduction group handling by moving it out of persistent quantizer state and into per-call parameters, which simplifies
Confidence Score: 5/5Safe to merge. The refactoring correctly preserves all amax-reduction semantics for both the activation quantization path and the FSDP2 weight path. The core invariant — that every quantization call that previously performed an amax all-reduce still does so after the refactor — holds across all three module types and the FSDP2 weight-update path. The only substantive asymmetry is that The dual code paths in Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Module
participant FwdBwdFn as Forward/Backward Fn
participant Helper as set_quantizer_amax_reduction_group
participant Quantizer
participant texQuantize as tex.quantize (C++)
Note over Module,texQuantize: Per-call group injection (activation path)
Module->>FwdBwdFn: forward(input) / backward(grad)
FwdBwdFn->>Helper: set_quantizer_amax_reduction_group(q, tp_group or None)
Helper->>Quantizer: "q.with_amax_reduction = True/False"
Helper->>Quantizer: "q.amax_reduction_group = group/None"
FwdBwdFn->>Quantizer: quantizer(tensor)
Quantizer->>texQuantize: tex.quantize(src, self, dst)
Note over Module,texQuantize: FSDP2 weight path
Module->>Module: fsdp_pre_all_gather(mesh)
Module->>Module: "tensor.amax_reduction_group = mesh.get_group()"
Module->>Module: "tensor.data = fp32_weight"
Module->>Quantizer: quantizer.copy() to throwaway
Module->>Quantizer: "throwaway.with_amax_reduction = True"
Module->>texQuantize: tex.quantize(src, throwaway, dst)
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant Module
participant FwdBwdFn as Forward/Backward Fn
participant Helper as set_quantizer_amax_reduction_group
participant Quantizer
participant texQuantize as tex.quantize (C++)
Note over Module,texQuantize: Per-call group injection (activation path)
Module->>FwdBwdFn: forward(input) / backward(grad)
FwdBwdFn->>Helper: set_quantizer_amax_reduction_group(q, tp_group or None)
Helper->>Quantizer: "q.with_amax_reduction = True/False"
Helper->>Quantizer: "q.amax_reduction_group = group/None"
FwdBwdFn->>Quantizer: quantizer(tensor)
Quantizer->>texQuantize: tex.quantize(src, self, dst)
Note over Module,texQuantize: FSDP2 weight path
Module->>Module: fsdp_pre_all_gather(mesh)
Module->>Module: "tensor.amax_reduction_group = mesh.get_group()"
Module->>Module: "tensor.data = fp32_weight"
Module->>Quantizer: quantizer.copy() to throwaway
Module->>Quantizer: "throwaway.with_amax_reduction = True"
Module->>texQuantize: tex.quantize(src, throwaway, dst)
Reviews (6): Last reviewed commit: "Drop redundant placeholder comments left..." | Re-trigger Greptile |
| """Quantize tensor""" | ||
| return self.quantize(tensor) | ||
| if amax_reduction_group is None: | ||
| return self.quantize(tensor) | ||
| return self.quantize(tensor, amax_reduction_group=amax_reduction_group) |
There was a problem hiding this comment.
The
None guard here is redundant: self.quantize(tensor) and self.quantize(tensor, amax_reduction_group=None) are identical because quantize defaults the argument to None. The branch just adds noise.
| """Quantize tensor""" | |
| return self.quantize(tensor) | |
| if amax_reduction_group is None: | |
| return self.quantize(tensor) | |
| return self.quantize(tensor, amax_reduction_group=amax_reduction_group) | |
| """Quantize tensor""" | |
| return self.quantize(tensor, amax_reduction_group=amax_reduction_group) |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| @property | ||
| def rht_matrix(self) -> torch.Tensor: | ||
| """RHT matrix (fetched from the process-global cache, not stored per quantizer).""" | ||
| return get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) |
There was a problem hiding this comment.
Deserialization break for old pickled
NVFP4Quantizer instances
rht_matrix is now a property that reads self._with_random_sign_mask, but _with_random_sign_mask is a new field that did not exist in pickled state produced before this change. When Python's default __setstate__ (i.e., self.__dict__.update(state)) loads an old pickle, _with_random_sign_mask is absent, so any access to the rht_matrix property raises AttributeError. A __setstate__ that infers _with_random_sign_mask from the old stored rht_matrix (or supplies a safe default) would preserve backward compatibility for serialized quantizers.
|
Blocked by FSDP bug, refactor in progress. I plan to store .amax_reduction_group in QuantizedTensor. |
There was a problem hiding this comment.
This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).
Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:
class Float8CurrentScalingQuantizer(Quantizer):
_communicator_cache = {}
@property
def amax_reduction_group(self):
if self._amax_reduction_group_id is None:
return None
return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]
@property.setter
def amax_reduction_group(self, comm):
if comm is None:
self._amax_reduction_group_id = None
self._amax_reduction_group_id = id(comm)
Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = commI'm not sure how this would interact with checkpointing though.
| dst: QuantizedTensor, | ||
| *, | ||
| noop_flag: Optional[torch.Tensor] = None, | ||
| amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument |
There was a problem hiding this comment.
I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.
|
/te-ci pytorch L1 |
… fwd/bwd impls The amax reduction process group is no longer stored persistently on the module quantizer at setup. Two paths, no C++ changes (the cast kernel still reads the group off the quantizer at quantize time): - TP sequence parallel: set the group on the input/grad-output quantizer at point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp, ops basic_linear) via set_quantizer_amax_reduction_group, replacing the setup-time _customize_quantizers group-setting (and dropping _customize_quantizers_nvfp4). Avoids a shared quantizer's group being clobbered across interleaved passes. - FSDP2: store the group on Float8Tensor/NVFP4Tensor (set in fsdp_pre_all_gather) and bridge it onto the quantizer before the in-place re-quantization in update_quantized / _set_data, so all weight shards stay globally scaled. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
e9097d6 to
948cd6d
Compare
for more information, see https://pre-commit.ci
The amax reduction group must only ride on the QuantizedTensor, never on its quantizer. copy() no longer propagates with_amax_reduction/amax_reduction_group (so every freshly-quantized output tensor gets a clean quantizer), and the in-place FSDP re-quant paths (_set_data, update_quantized for Float8 and NVFP4) apply the group on a throwaway quantizer copy instead of mutating the tensor's own quantizer. Also trim verbose comments. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…oup wiring Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Strip the group off the output tensor's quantizer at the quantize() chokepoint after a fresh quantize, and restore copy() to upstream behavior. This keeps the group off every tensor's quantizer (FSDP weight init, _set_data, TP activations) while leaving copy() — and thus fsdp_post_all_gather / make_like / explicit .copy() — unchanged, narrowing the backward-compat impact. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Description
This makes adding torch.compile support much easier.
Move amax reduction process group handling out of quantizer state and pass it per quantization call instead. This avoids storing process groups inside quantizers while keeping deprecated stored-group fallback behavior for compatibility.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
amax_reduction_groupthroughquantize/module call paths instead of storing it on quantizers.Checklist: