
[PyTorch] Split TE ops op_forward into op_forward and setup_context #2877

Open
pggPL wants to merge 6 commits into NVIDIA:main from
pggPL:torch_compile_sequential_split_into_setup_ctx_and_forward

Conversation

Collaborator

@pggPL pggPL commented Apr 14, 2026

Description

Split the forward pass of fusible operations (op_forward / fuser_forward) into two separate methods — compute and save_ctx — to enable torch.compile compatibility. Under torch.compile, the forward graph is traced without side effects, so separating pure computation from context saving allows the compute portion to be compiled while context saving remains outside the compiled region.
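To make the split concrete, here is a standalone sketch (not the actual transformer_engine classes; the `_Ctx` stand-in and `ScaleOp` are hypothetical) of an op written against the split API described above: every tensor the backward pass needs is returned from the compute step and written into the context only inside save_ctx, so the compute step stays side-effect free.

```python
class _Ctx:
    """Stand-in for an operation context that records saved tensors."""

    def __init__(self):
        self.saved_tensors = None

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors


class ScaleOp:
    """Hypothetical op: y = scale * x, saving x's value for backward."""

    def __init__(self, scale):
        self.scale = scale

    def op_forward_compute(self, x):
        # Pure computation: no writes to any context object.
        out = self.scale * x
        tensors_to_save = (x,)
        return out, tensors_to_save

    def op_forward_save_ctx(self, ctx, tensors_to_save):
        # Side effects live here, outside the compiled region.
        ctx.save_for_backward(*tensors_to_save)
        ctx.scale = self.scale

    def op_backward(self, ctx, grad_out):
        (x,) = ctx.saved_tensors
        return ctx.scale * grad_out  # d(scale * x)/dx = scale


op = ScaleOp(scale=3.0)
ctx = _Ctx()
y, to_save = op.op_forward_compute(2.0)
op.op_forward_save_ctx(ctx, to_save)
grad = op.op_backward(ctx, 1.0)
```

Under torch.compile, only `op_forward_compute` would sit inside the traced region; `op_forward_save_ctx` runs outside it.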

Type of change: New feature (non-breaking change which adds functionality)

Changes

  • Split API for BasicOperation: Added op_forward_compute and op_forward_save_ctx as an alternative to op_forward. Subclasses implement either the legacy single-method API or the new split API (defining both is an error).
  • Split API for FusibleOperation / FusedOperation: Added fuser_forward_compute and fuser_forward_save_ctx as an alternative to fuser_forward, with the same auto-detection via __init_subclass__.
  • Migrated all basic ops (BasicLinear, Bias, Activation, SwiGLU, LayerNorm, RMSNorm, Dropout, Quantize, AllGather, AllReduce, ReduceScatter, Reshape, Identity, ConstantScale, L2Normalization) to the split API.
  • Migrated fused ops (ForwardLinearBiasAdd, ForwardLinearBiasActivation, ForwardLinearScaleAdd, ForwardGroupedMLP, UserbuffersForwardLinear) to the split API.
  • Fuser dispatch: Updated OperationFuser to detect _use_split_forward and call the appropriate path, sharing fwd_kwargs between compute and save_ctx.
  • Type annotation fix: tensors_to_save now correctly typed as Optional[Union[torch.Tensor, QuantizedTensorStorage]] instead of Optional[torch.Tensor].
  • Documentation: Exposed fuser_forward_compute, fuser_forward_save_ctx, op_forward_compute, op_forward_save_ctx in docs/api/pytorch.rst. Added class-level docstrings documenting the two API choices.
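The fuser-dispatch change can be sketched as follows (simplified; the real OperationFuser also handles fused-op groups, extra inputs/outputs, and autograd integration — the `_use_split_forward` flag name comes from the PR description, everything else here is illustrative):

```python
class _Ctx:
    def __init__(self):
        self.saved = None


def run_op_forward(op, ctx, x, **fwd_kwargs):
    """Dispatch between the legacy and split forward APIs."""
    if getattr(op, "_use_split_forward", False):
        # Split API: pure compute first, then context saving,
        # sharing the same fwd_kwargs between the two calls.
        out, tensors_to_save = op.op_forward_compute(x, **fwd_kwargs)
        op.op_forward_save_ctx(ctx, tensors_to_save, **fwd_kwargs)
    else:
        # Legacy API: one method that both computes and saves context.
        out = op.op_forward(ctx, x, **fwd_kwargs)
    return out


class LegacyDouble:
    _use_split_forward = False

    def op_forward(self, ctx, x):
        ctx.saved = (x,)
        return 2 * x


class SplitDouble:
    _use_split_forward = True

    def op_forward_compute(self, x):
        return 2 * x, (x,)

    def op_forward_save_ctx(self, ctx, tensors_to_save):
        ctx.saved = tensors_to_save


ctx_a, ctx_b = _Ctx(), _Ctx()
out_a = run_op_forward(LegacyDouble(), ctx_a, 5)
out_b = run_op_forward(SplitDouble(), ctx_b, 5)
```

Both paths produce the same output and the same saved context, which is what lets the migration proceed op by op.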

Checklist

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread transformer_engine/pytorch/ops/basic/activation.py Outdated
Comment thread transformer_engine/pytorch/ops/fuser.py Outdated
Comment thread transformer_engine/pytorch/ops/op.py Outdated

@abc.abstractmethod
def op_forward(
def op_forward(self, *args: Any, **kwargs: Any) -> Any:
Collaborator


This is trying to use the same function name to support two different behaviors, with two different APIs. This will be hard to reason about and there will be many bugs from using the wrong API. It seems more straightforward to implement a new function:

def op_forward(self, ctx, ...):
    """Forward pass"""

    # Some advanced ops may override this default
    self.op_forward_compute(...)
    self.op_forward_setup_context(ctx)

def op_forward_compute(self,  ...):
    """Forward pass compute"""
    ...

def op_forward_setup_context(self, ctx):
    """Setup context for backward pass"""
    ...

Collaborator Author


That's a good point.

Comment thread transformer_engine/pytorch/ops/op.py Outdated
#: ``setup_tensors`` (a :obj:`FusedOpSetupTensors`, or ``None``).
separate_fuser_setup_context: ClassVar[bool] = False

def fuser_setup_context(
Collaborator


Right now this requires that the context state is a kwarg, lives in the op, or is a tensor. However, what if we want to save other context state? For example, maybe one of the saved tensors could have multiple possible formats and we need an enum or something to keep track (e.g. what if grouped linear op could save either split sizes or split offsets). My first thought is to allow op_forward to pass a dict or tuple to setup_context, although I'm not sure how well this would work with torch.compile.

x, fused_op_extra_outputs, setup_tensors, context_state = op.fuser_forward(...)
op.fuser_setup_context(ctxs, ..., setup_tensors=setup_tensors, context_state=context_state)

If we're lucky, could we get rid of setup_tensors entirely and just store the tensors in a dict?

Collaborator Author


This will not work with torch.compile, and that is the main motivation for this change. If we could do this, we could just return ctx and pass it into setup_context.

There is a purpose to the split. Setup context is run once with fake tensors (not once per iteration), so that the compiler can trace which tensors need to be saved and can link outputs of forward to inputs of backward. This helps the compiler decide which fusions to perform and where to apply recomputation.

So basically, if we are unable to split setup context and forward, we should use the old API, and it will result in a graph break.
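The point about setup context running once with fake tensors can be illustrated with a pure-Python sketch (no torch; `RecordingCtx` and `AddBias` are made-up names for the illustration): the context-saving step is traced a single time with placeholder values to record which tensors the op saves, while the pure compute step runs every iteration.

```python
class RecordingCtx:
    """Records the positions of the tensors an op asks to save."""

    def __init__(self):
        self.save_plan = None

    def save_for_backward(self, *tensors):
        self.save_plan = tuple(range(len(tensors)))


class AddBias:
    def op_forward_compute(self, x, bias):
        return x + bias, (x,)  # save x for the backward pass

    def op_forward_save_ctx(self, ctx, tensors_to_save):
        ctx.save_for_backward(*tensors_to_save)


op = AddBias()

# "Compile time": run save_ctx once with fake inputs to learn the save plan.
trace_ctx = RecordingCtx()
_, fake_saved = op.op_forward_compute(0.0, 0.0)  # placeholder values
op.op_forward_save_ctx(trace_ctx, fake_saved)

# "Run time": only the pure compute step executes per iteration.
outputs = [op.op_forward_compute(x, 1.0)[0] for x in (1.0, 2.0, 3.0)]
```

If save_ctx instead depended on dynamic per-iteration state (the dict proposal above), the one-shot trace could not determine the save plan, hence the graph break.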

Comment thread transformer_engine/pytorch/ops/op.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py Outdated
pggPL and others added 3 commits April 15, 2026 17:05
Refactor the ops API to support splitting forward passes into two
phases: pure computation (op_forward_compute / fuser_forward_compute)
and context saving (op_forward_save_ctx / fuser_forward_save_ctx).
This separation enables torch.compile compatibility by isolating
side effects from computation.

- Add fuser_forward_compute and fuser_forward_save_ctx to
  FusibleOperation with __init_subclass__ validation
- Add op_forward_compute and op_forward_save_ctx to BasicOperation
  with compile-time checks against mixing legacy and split APIs
- Add _use_split_forward flag for runtime dispatch in fuser
- Convert all basic ops to the split API
- Convert all fused ops (ForwardLinearScaleAdd, ForwardLinearBiasAdd,
  ForwardLinearBiasActivation, UserbuffersForwardLinear,
  ForwardGroupedMLP) to the split API, delegating save_ctx to basic
  ops where possible

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Remove redundant op_forward/op_backward guard methods from
AddExtraInput, MakeExtraOutput, GroupedLinear, and _ScaledGLU.
The _check_no_extra_io helper already catches these cases.

Reset _use_split_forward=False in BasicOperation.__init_subclass__
when a subclass defines op_forward (legacy API), so custom ops
don't inherit the split dispatch path from BasicOperation.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
@pggPL pggPL force-pushed the torch_compile_sequential_split_into_setup_ctx_and_forward branch from 2efb140 to 796e780 on April 15, 2026 15:33
pggPL added 2 commits April 15, 2026 17:41
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Expose fuser_forward_compute, fuser_forward_save_ctx, op_forward_compute,
and op_forward_save_ctx in API documentation. Document the two forward
pass APIs (legacy vs split) in class docstrings. Fix type annotations
for tensors_to_save to accept QuantizedTensorStorage alongside Tensor.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
@pggPL pggPL marked this pull request as ready for review April 16, 2026 09:48
@greptile-apps
Contributor

greptile-apps bot commented Apr 16, 2026

Greptile Summary

This PR splits op_forward/fuser_forward into op_forward_compute + op_forward_save_ctx (and their fuser counterparts) to enable torch.compile compatibility. All existing basic and fused operations are migrated to the new split API, and the OperationFuser dispatch is updated accordingly. The design is sound and the migration is mechanically consistent across the codebase — the global-state re-queries in save_ctx are intentional for torch.compile compatibility.

Confidence Score: 5/5

Safe to merge — all remaining findings are P2 style/robustness suggestions that do not affect correctness for any existing code path.

The core split-API design is sound and all migrated operations have been verified to preserve the original forward/backward semantics. The two flagged issues are an asymmetric safety check in FusibleOperation.__init_subclass__ (affects only future subclass authors, not existing code) and a duplicated impl-selection block in Dropout (a maintenance risk, not a present bug). No data-loss, security, or build issues found.

Flagged files: transformer_engine/pytorch/ops/op.py (missing fuser_forward + fuser_forward_compute conflict guard) and transformer_engine/pytorch/ops/basic/dropout.py (duplicated impl logic).

Important Files Changed

Filename: Overview
transformer_engine/pytorch/ops/op.py: Introduces split API base-class methods and __init_subclass__ hooks; BasicOperation correctly validates the API pair but FusibleOperation omits the fuser_forward + fuser_forward_compute conflict check.
transformer_engine/pytorch/ops/fuser.py: Dispatch correctly branches on _use_split_forward; shared fwd_kwargs dict is passed cleanly to both compute and save_ctx.
transformer_engine/pytorch/ops/basic/dropout.py: impl selection logic is re-computed independently in op_forward_save_ctx, duplicating the logic from op_forward_compute and risking silent divergence if one branch is changed.
transformer_engine/pytorch/ops/basic/basic_linear.py: Clean split; FP8/quantizer re-queries in op_forward_save_ctx are intentional for the torch.compile boundary.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: base_offsets is correctly passed through tensors_to_save and unpacked as a non-saved attribute, preserving original backward semantics.
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py: CPU offload and backward-override bias quantizer handling correctly delegated to BasicLinear.op_forward_save_ctx and Bias.op_forward_save_ctx.
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py: Clean migration; delegates context saving to constituent op save_ctx methods correctly.
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py: Split migration follows the same pattern as other fused ops; input_dims set directly in save_ctx outside the delegation.

Sequence Diagram

sequenceDiagram
    participant Fuser as OperationFuser
    participant FC as fuser_forward_compute
    participant FS as fuser_forward_save_ctx
    participant Ctx as OperationContext

    Fuser->>FC: op.fuser_forward_compute(x, **fwd_kwargs)
    Note over FC: Pure computation — compilable region
    FC-->>Fuser: (output, extra_outputs, tensors_to_save)

    Fuser->>FS: op.fuser_forward_save_ctx(ctxs, x, tensors_to_save, **fwd_kwargs)
    Note over FS: Side effects — outside compiled region
    FS->>Ctx: ctx.save_for_backward(*tensors)
    FS->>Ctx: ctx.dtype = ...
    FS->>Ctx: ctx.quantizer = ...
    FS-->>Fuser: None

    Fuser->>Fuser: flatten & save all ctx tensors via func_ctx.save_for_backward

Reviews (2): Last reviewed commit: "Require paired split API methods and wid..."

Comment thread transformer_engine/pytorch/ops/op.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Add symmetric __init_subclass__ checks: defining op_forward_compute
without op_forward_save_ctx (or fuser_forward_compute without
fuser_forward_save_ctx) now raises TypeError at class definition time
instead of hitting NotImplementedError at runtime.

Widen tensors_to_save type annotations from Optional[torch.Tensor] to
Optional[Union[torch.Tensor, QuantizedTensorStorage]] across all 13
subclass files to match the base class and allow saving storage objects.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
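The symmetric check described in this commit can be sketched as follows (a simplified stand-in for the real __init_subclass__ logic; the class bodies are illustrative): defining one half of the split API without the other raises TypeError at class definition time instead of surfacing as NotImplementedError at runtime.

```python
class BasicOperation:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        has_compute = "op_forward_compute" in cls.__dict__
        has_save_ctx = "op_forward_save_ctx" in cls.__dict__
        if has_compute != has_save_ctx:
            raise TypeError(
                f"{cls.__name__} must define both op_forward_compute "
                "and op_forward_save_ctx, or neither"
            )
        cls._use_split_forward = has_compute


class GoodOp(BasicOperation):
    def op_forward_compute(self, x):
        return x, ()

    def op_forward_save_ctx(self, ctx, tensors_to_save):
        pass


try:
    class BadOp(BasicOperation):
        def op_forward_compute(self, x):  # no op_forward_save_ctx pair
            return x, ()
    raised = False
except TypeError:
    raised = True
```

Checking `cls.__dict__` rather than `hasattr` means only methods defined directly on the subclass count, so inherited defaults do not trip the guard.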
@pggPL
Collaborator Author

pggPL commented Apr 16, 2026

@timmoon10 I've cleaned up the code (there was some AI slop I hadn't looked into) and opened the PR. I think it is fine now.
