[PyTorch] Split TE ops op_forward into op_forward and setup_context #2877
pggPL wants to merge 6 commits into NVIDIA:main
Conversation
```python
@abc.abstractmethod
def op_forward(self, *args: Any, **kwargs: Any) -> Any:
```
This is trying to use the same function name to support two different behaviors, with two different APIs. This will be hard to reason about and there will be many bugs from using the wrong API. It seems more straightforward to implement a new function:
```python
def op_forward(self, ctx, ...):
    """Forward pass"""
    # Some advanced ops may override this default
    self.op_forward_compute(...)
    self.op_forward_setup_context(ctx)

def op_forward_compute(self, ...):
    """Forward pass compute"""
    ...

def op_forward_setup_context(self, ctx):
    """Setup context for backward pass"""
    ...
```

```python
#: ``setup_tensors`` (a :obj:`FusedOpSetupTensors`, or ``None``).
separate_fuser_setup_context: ClassVar[bool] = False
```

```python
def fuser_setup_context(
```
Right now this requires that the context state is a kwarg, lives in the op, or is a tensor. However, what if we want to save other context state? For example, maybe one of the saved tensors could have multiple possible formats and we need an enum or something to keep track (e.g. what if grouped linear op could save either split sizes or split offsets). My first thought is to allow op_forward to pass a dict or tuple to setup_context, although I'm not sure how well this would work with torch.compile.
```python
x, fused_op_extra_outputs, setup_tensors, context_state = op.fuser_forward(...)
op.fuser_setup_context(ctxs, ..., setup_tensors=setup_tensors, context_state=context_state)
```

If we're lucky, could we get rid of setup_tensors entirely and just store the tensors in a dict?
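The `context_state` idea could be sketched roughly as follows. This is a hypothetical illustration in plain Python, not code from the PR: `SplitFormat`, `FakeCtx`, and the function bodies are made up to show non-tensor state (an enum tag) flowing separately from the tensors.

```python
# Hypothetical sketch of the reviewer's "context_state" idea: the forward
# returns non-tensor state (here, an enum tag) alongside the tensors, and
# the setup-context step stores both on the autograd context.
from enum import Enum

class SplitFormat(Enum):
    SIZES = 0
    OFFSETS = 1

class FakeCtx:
    """Stand-in for an autograd context object."""
    def __init__(self):
        self.saved_tensors = None
        self.split_format = None

def fuser_forward(x, splits):
    # Pure computation; returns data plus non-tensor context state.
    setup_tensors = (x,)
    context_state = {"split_format": SplitFormat.SIZES, "splits": splits}
    return x, setup_tensors, context_state

def fuser_setup_context(ctx, setup_tensors, context_state):
    # Side effects only: record what the backward pass will need.
    ctx.saved_tensors = setup_tensors
    ctx.split_format = context_state["split_format"]

ctx = FakeCtx()
out, tensors, state = fuser_forward([1.0, 2.0], splits=[1, 1])
fuser_setup_context(ctx, tensors, state)
```

Whether a dict like this survives `torch.compile` tracing is exactly the open question raised above.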
This will not work with torch.compile, and that is the main motivation for this change. If we could do this, we could just return ctx and pass it into setup_context.

There is a purpose in that. setup_context is run once with fake tensors (not once per iteration), so that the compiler can trace which tensors need to be saved and is able to link outputs from forward to inputs for backward. This helps the compiler decide which fusions to perform and where to apply recomputation.

So basically, if we are unable to split setup_context from forward, we should use the old API, and it will result in a graph break.
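The distinction can be illustrated in plain Python (hypothetical names and bodies; in the real PR the compute half is what torch.compile traces, while save_ctx runs once with fake tensors outside the compiled region):

```python
# Legacy API: one method that both computes and mutates the context.
# The mutation is a side effect inside the traced region, which forces
# a graph break under torch.compile.
def op_forward_legacy(ctx, x):
    y = [v * 2.0 for v in x]   # computation
    ctx["saved"] = x           # side effect
    return y

# Split API: the compute half is a pure function of its inputs, so it
# can be compiled; the save_ctx half holds all the side effects.
def op_forward_compute(x):
    y = [v * 2.0 for v in x]
    tensors_to_save = (x,)
    return y, tensors_to_save

def op_forward_save_ctx(ctx, tensors_to_save):
    ctx["saved"] = tensors_to_save

ctx = {}
y, to_save = op_forward_compute([1.0, 3.0])
op_forward_save_ctx(ctx, to_save)
```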
Refactor the ops API to support splitting forward passes into two phases: pure computation (op_forward_compute / fuser_forward_compute) and context saving (op_forward_save_ctx / fuser_forward_save_ctx). This separation enables torch.compile compatibility by isolating side effects from computation.

- Add fuser_forward_compute and fuser_forward_save_ctx to FusibleOperation with __init_subclass__ validation
- Add op_forward_compute and op_forward_save_ctx to BasicOperation with compile-time checks against mixing legacy and split APIs
- Add _use_split_forward flag for runtime dispatch in fuser
- Convert all basic ops to the split API
- Convert all fused ops (ForwardLinearScaleAdd, ForwardLinearBiasAdd, ForwardLinearBiasActivation, UserbuffersForwardLinear, ForwardGroupedMLP) to the split API, delegating save_ctx to basic ops where possible

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Remove redundant op_forward/op_backward guard methods from AddExtraInput, MakeExtraOutput, GroupedLinear, and _ScaledGLU. The _check_no_extra_io helper already catches these cases.

Reset _use_split_forward=False in BasicOperation.__init_subclass__ when a subclass defines op_forward (legacy API), so custom ops don't inherit the split dispatch path from BasicOperation.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
for more information, see https://pre-commit.ci
Force-pushed from 2efb140 to 796e780
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
Expose fuser_forward_compute, fuser_forward_save_ctx, op_forward_compute, and op_forward_save_ctx in API documentation. Document the two forward pass APIs (legacy vs split) in class docstrings. Fix type annotations for tensors_to_save to accept QuantizedTensorStorage alongside Tensor. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
Greptile Summary

This PR splits the forward pass of fusible operations into separate compute and save_ctx phases.

Confidence Score: 5/5

Safe to merge. All remaining findings are P2 style/robustness suggestions that do not affect correctness for any existing code path. The core split-API design is sound, and all migrated operations have been verified to preserve the original forward/backward semantics. The two flagged issues are an asymmetric safety check in `FusibleOperation.__init_subclass__` (affects only future subclass authors, not existing code) and a duplicated impl-selection block in `Dropout` (maintenance risk, not a present bug). No data-loss, security, or build issues found. Flagged files: `transformer_engine/pytorch/ops/op.py` (missing `fuser_forward` + `fuser_forward_compute` conflict guard) and `transformer_engine/pytorch/ops/basic/dropout.py` (duplicated impl logic).

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Fuser as OperationFuser
    participant FC as fuser_forward_compute
    participant FS as fuser_forward_save_ctx
    participant Ctx as OperationContext
    Fuser->>FC: op.fuser_forward_compute(x, **fwd_kwargs)
    Note over FC: Pure computation (compilable region)
    FC-->>Fuser: (output, extra_outputs, tensors_to_save)
    Fuser->>FS: op.fuser_forward_save_ctx(ctxs, x, tensors_to_save, **fwd_kwargs)
    Note over FS: Side effects (outside compiled region)
    FS->>Ctx: ctx.save_for_backward(*tensors)
    FS->>Ctx: ctx.dtype = ...
    FS->>Ctx: ctx.quantizer = ...
    FS-->>Fuser: None
    Fuser->>Fuser: flatten & save all ctx tensors via func_ctx.save_for_backward
```
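The fuser-side dispatch this diagram describes could look roughly like the sketch below. This is a simplified, hypothetical rendering (signatures and bodies are guessed from the diagram and the commit messages, not copied from the PR); only `_use_split_forward` and the method names come from the source.

```python
class Op:
    # Subclasses using the split API set this flag (the PR flips it in
    # __init_subclass__); legacy subclasses leave it False.
    _use_split_forward = False

class SplitOp(Op):
    _use_split_forward = True

    def fuser_forward_compute(self, x, **fwd_kwargs):
        # Pure computation: the compilable region.
        out = [v + 1.0 for v in x]
        tensors_to_save = (x,)
        return out, tensors_to_save

    def fuser_forward_save_ctx(self, ctx, tensors_to_save, **fwd_kwargs):
        # Side effects: runs outside the compiled region.
        ctx["saved"] = tensors_to_save

class LegacyOp(Op):
    def fuser_forward(self, ctx, x, **fwd_kwargs):
        # Legacy single-method API: compute and context saving together.
        out = [v + 1.0 for v in x]
        ctx["saved"] = (x,)
        return out

def run_forward(op, ctx, x, **fwd_kwargs):
    """Simplified stand-in for the OperationFuser dispatch."""
    if op._use_split_forward:
        out, to_save = op.fuser_forward_compute(x, **fwd_kwargs)
        op.fuser_forward_save_ctx(ctx, to_save, **fwd_kwargs)
        return out
    return op.fuser_forward(ctx, x, **fwd_kwargs)

ctx = {}
out = run_forward(SplitOp(), ctx, [1.0, 2.0])
```

Note how the same `fwd_kwargs` are shared between `compute` and `save_ctx`, matching the change described in the PR.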
Reviews (2): Last reviewed commit: "Require paired split API methods and wid..."
Add symmetric __init_subclass__ checks: defining op_forward_compute without op_forward_save_ctx (or fuser_forward_compute without fuser_forward_save_ctx) now raises TypeError at class definition time instead of hitting NotImplementedError at runtime. Widen tensors_to_save type annotations from Optional[torch.Tensor] to Optional[Union[torch.Tensor, QuantizedTensorStorage]] across all 13 subclass files to match the base class and allow saving storage objects. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
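The paired-method check described in this commit can be sketched as follows. This is a hypothetical simplification of the actual `__init_subclass__` logic, not the PR's code: it only inspects the immediate class body and omits the legacy-API and fuser-side checks.

```python
class BasicOperation:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Require the split-API methods to come in pairs, failing at
        # class-definition time rather than with NotImplementedError later.
        has_compute = "op_forward_compute" in cls.__dict__
        has_save = "op_forward_save_ctx" in cls.__dict__
        if has_compute != has_save:
            raise TypeError(
                f"{cls.__name__} must define both op_forward_compute and "
                "op_forward_save_ctx, or neither"
            )

class GoodOp(BasicOperation):
    def op_forward_compute(self, x):
        return x

    def op_forward_save_ctx(self, ctx, tensors_to_save):
        pass

try:
    class BadOp(BasicOperation):
        def op_forward_compute(self, x):  # no matching save_ctx
            return x
    raised = False
except TypeError:
    raised = True
```

Failing in `__init_subclass__` means the mistake surfaces at import time, before any forward pass runs.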
@timmoon10 I've cleaned the code (there was some AI slop I hadn't looked into) and opened the PR; I think it is fine now.
Description
Split the forward pass of fusible operations (`op_forward`/`fuser_forward`) into two separate methods, `compute` and `save_ctx`, to enable `torch.compile` compatibility. Under `torch.compile`, the forward graph is traced without side effects, so separating pure computation from context saving allows the compute portion to be compiled while context saving remains outside the compiled region.

Type of change: New feature (non-breaking change which adds functionality)
Changes
- `BasicOperation`: Added `op_forward_compute` and `op_forward_save_ctx` as an alternative to `op_forward`. Subclasses implement either the legacy single-method API or the new split API (defining both is an error).
- `FusibleOperation`/`FusedOperation`: Added `fuser_forward_compute` and `fuser_forward_save_ctx` as an alternative to `fuser_forward`, with the same auto-detection via `__init_subclass__`.
- Converted all basic ops (`BasicLinear`, `Bias`, `Activation`, `SwiGLU`, `LayerNorm`, `RMSNorm`, `Dropout`, `Quantize`, `AllGather`, `AllReduce`, `ReduceScatter`, `Reshape`, `Identity`, `ConstantScale`, `L2Normalization`) to the split API.
- Converted all fused ops (`ForwardLinearBiasAdd`, `ForwardLinearBiasActivation`, `ForwardLinearScaleAdd`, `ForwardGroupedMLP`, `UserbuffersForwardLinear`) to the split API.
- Updated `OperationFuser` to detect `_use_split_forward` and call the appropriate path, sharing `fwd_kwargs` between `compute` and `save_ctx`.
- `tensors_to_save` is now correctly typed as `Optional[Union[torch.Tensor, QuantizedTensorStorage]]` instead of `Optional[torch.Tensor]`.
- Exposed `fuser_forward_compute`, `fuser_forward_save_ctx`, `op_forward_compute`, and `op_forward_save_ctx` in `docs/api/pytorch.rst`. Added class-level docstrings documenting the two API choices.

Checklist