diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index 3217d29c3b..29c841c2f3 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -154,13 +154,13 @@ Operation fuser
    :members: forward

 .. autoapiclass:: transformer_engine.pytorch.ops.FusibleOperation
-   :members: fuser_forward, fuser_backward
+   :members: fuser_forward, fuser_forward_compute, fuser_forward_save_ctx, fuser_backward

 .. autoapiclass:: transformer_engine.pytorch.ops.BasicOperation
-   :members: op_forward, op_backward
+   :members: op_forward, op_forward_compute, op_forward_save_ctx, op_backward

 .. autoapiclass:: transformer_engine.pytorch.ops.FusedOperation
-   :members: fuser_forward, fuser_backward
+   :members: fuser_forward, fuser_forward_compute, fuser_forward_save_ctx, fuser_backward

 .. autoapifunction:: transformer_engine.pytorch.ops.register_forward_fusion

diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py
index 13cb519c19..3d06826f51 100644
--- a/transformer_engine/pytorch/ops/basic/activation.py
+++ b/transformer_engine/pytorch/ops/basic/activation.py
@@ -6,11 +6,12 @@
 from __future__ import annotations
 import abc
-from typing import Optional
+from typing import Optional, Union

 import torch

 import transformer_engine_torch as tex
+from ...quantized_tensor import QuantizedTensorStorage
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
 from ...utils import clear_tensor_data
@@ -80,13 +81,14 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
         """

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

         # Compute dtype
         dtype: torch.dtype
@@ -109,15 +111,31 @@ def op_forward(
             input_quantizer.set_usage(rowwise=True, columnwise=False)
             x = input_quantizer(x)

-        # Save state for backward pass
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(x)
-            ctx.save_for_backward(x)
-            ctx.dtype = dtype
-            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
+        if requires_grad:
+            return y, (x,)
+        return y, (None,)

-        return y
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        (x,) = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(x)
+        ctx.save_for_backward(x)
+        if torch.is_autocast_enabled():
+            ctx.dtype = torch.get_autocast_dtype("cuda")
+        else:
+            ctx.dtype = input_.dtype
+        ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

     def op_backward(
         self,
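For orientation, the pattern this patch applies to every operation: the old `op_forward(ctx, ...)` is split into a pure `op_forward_compute`, which returns the output plus the tuple of tensors the backward pass needs, and a bookkeeping-only `op_forward_save_ctx`, which files those tensors into the autograd context. A minimal sketch of a custom op written against the new protocol — this class is illustrative, not part of the patch, and the `OperationContext` import path plus the abbreviated `op_backward` signature are assumptions:

    from typing import Optional

    import torch

    from transformer_engine.pytorch.ops import BasicOperation
    # OperationContext's public import path is assumed here.
    from transformer_engine.pytorch.ops.op import OperationContext
    from transformer_engine.pytorch.tensor import Quantizer


    class Square(BasicOperation):
        """Toy op: y = x * x, saving x for backward."""

        def op_forward_compute(
            self,
            input_: torch.Tensor,
            *,
            requires_grad: bool,
            prev_op_grad_output_quantizer: Optional[Quantizer] = None,
            next_op_input_quantizer: Optional[Quantizer] = None,
        ) -> tuple[torch.Tensor, tuple[Optional[torch.Tensor], ...]]:
            # Pure compute: no autograd context is touched here.
            out = input_ * input_
            if requires_grad:
                return out, (input_,)
            return out, (None,)

        def op_forward_save_ctx(
            self,
            ctx: OperationContext,
            input_: torch.Tensor,
            tensors_to_save: tuple[Optional[torch.Tensor], ...],
            *,
            requires_grad: bool,
            prev_op_grad_output_quantizer: Optional[Quantizer] = None,
            next_op_input_quantizer: Optional[Quantizer] = None,
        ) -> None:
            # Bookkeeping only: stash whatever op_forward_compute returned.
            if not requires_grad:
                return
            (x,) = tensors_to_save
            ctx.save_for_backward(x)

        def op_backward(self, ctx, grad_output):
            # Signature abbreviated; the real base class also returns
            # parameter gradients alongside the input gradient.
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output, ()

Note that the split keeps what each op saves unchanged; only when it is saved moves.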
diff --git a/transformer_engine/pytorch/ops/basic/add_extra_input.py b/transformer_engine/pytorch/ops/basic/add_extra_input.py
index fc3ca9cade..d0bc441dec 100644
--- a/transformer_engine/pytorch/ops/basic/add_extra_input.py
+++ b/transformer_engine/pytorch/ops/basic/add_extra_input.py
@@ -42,32 +42,24 @@ def __init__(self,
         *, in_place: bool = False):
         super().__init__()
         self._in_place = in_place

-    def op_forward(self, *args, **kwargs) -> None:
-        raise RuntimeError(
-            "{self.__class__.__name__} operation has "
-            f"{self.num_extra_inputs} extra tensor inputs "
-            f"and {self.num_extra_outputs} extra tensor outputs. "
-            "It overrides `fuser_forward` instead of `op_forward`."
-        )
-
     def op_backward(self, *args, **kwargs) -> None:
         raise RuntimeError(
-            "{self.__class__.__name__} operation has "
+            f"{self.__class__.__name__} operation has "
             f"{self.num_extra_inputs} extra tensor inputs "
             f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_backward` instead of `op_backward`."
         )

-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]], list[tuple[()]]]:
         extra_input = basic_op_extra_inputs[0][0]
         if self._in_place:
             extra_input = extra_input.detach()
@@ -75,7 +67,21 @@ def fuser_forward(
             output = extra_input
         else:
             output = extra_input + input_
-        return output, [()]
+        return output, [()], [()]
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[()]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        pass

     def fuser_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py
index 4e5c192876..c886d692d0 100644
--- a/transformer_engine/pytorch/ops/basic/all_gather.py
+++ b/transformer_engine/pytorch/ops/basic/all_gather.py
@@ -36,19 +36,32 @@ def __init__(
         self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
         self.process_group_size: int = torch.distributed.get_world_size(process_group)

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:
         out: torch.Tensor
         if self.process_group_size == 1:
             out = input_.detach()
         else:
             out, _ = gather_along_first_dim(input_, self.process_group)
-        return out
+        return out, ()
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        pass

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py
index f2e4b2481d..aefe044ccd 100644
--- a/transformer_engine/pytorch/ops/basic/all_reduce.py
+++ b/transformer_engine/pytorch/ops/basic/all_reduce.py
@@ -38,22 +38,35 @@ def __init__(
         self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
         self._reduce_in_backward: bool = reduce_in_backward

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:

         # Trivial case
         if torch.distributed.get_world_size(self.process_group) == 1:
-            return input_
+            return input_, ()

         # Perform all-reduce
         x = maybe_dequantize(input_.contiguous())
         torch.distributed.all_reduce(x, group=self.process_group)
-        return x
+        return x, ()
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        pass

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index 17594726cc..af57040302 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -8,10 +8,11 @@
 from collections.abc import Callable, Iterable
 import contextlib
 import math
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch

+from ...quantized_tensor import QuantizedTensorStorage
 from ...cpp_extensions import general_gemm
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import (
@@ -979,24 +980,23 @@ def _functional_backward(
         _wait_async(dx_async)
         return dx, dw

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

         # Check which grads are required
-        input_requires_grad = ctx.requires_grad
-        weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
+        input_requires_grad = requires_grad
+        weight_requires_grad = requires_grad and self.weight.requires_grad

         # Quantizers
         input_quantizer = self.get_quantizer("forward", 0)
         weight_quantizer = self.get_quantizer("forward", 1)
         output_quantizer = next_op_input_quantizer
-        grad_output_quantizer = self.get_quantizer("backward", 0)
-        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
@@ -1026,28 +1026,56 @@ def op_forward(
             weight_requires_grad=weight_requires_grad,
         )

-        # Save state for backward pass
-        if ctx.requires_grad:
+        # Determine tensors to save for backward pass
+        if requires_grad:
             if backward_override == "high_precision":
                 saved_input = input_ if weight_requires_grad else None
                 saved_weight = self.weight if input_requires_grad else None
             else:
                 saved_input = x_local
                 saved_weight = w
-            if is_cpu_offload_enabled():
-                mark_activation_offload(saved_input)
-            ctx.save_for_backward(saved_input, saved_weight)
-            ctx.with_quantized_compute = with_quantized_compute and backward_override is None
-            ctx.backward_override = backward_override
-            ctx.input_quantizer = input_quantizer
-            ctx.weight_quantizer = weight_quantizer
-            ctx.grad_output_quantizer = grad_output_quantizer
-            ctx.grad_input_quantizer = grad_input_quantizer
-            ctx.dtype = dtype
-            ctx.input_requires_grad = input_requires_grad
-            ctx.weight_requires_grad = weight_requires_grad
-
-        return output
+        else:
+            saved_input = None
+            saved_weight = None
+
+        return output, (saved_input, saved_weight)
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+
+        saved_input, saved_weight = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(saved_input)
+        ctx.save_for_backward(saved_input, saved_weight)
+
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        if with_quantized_compute:
+            backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
+        else:
+            backward_override = None
+
+        ctx.with_quantized_compute = with_quantized_compute and backward_override is None
+        ctx.backward_override = backward_override
+        ctx.input_quantizer = self.get_quantizer("forward", 0)
+        ctx.weight_quantizer = self.get_quantizer("forward", 1)
+        ctx.grad_output_quantizer = self.get_quantizer("backward", 0)
+        ctx.grad_input_quantizer = prev_op_grad_output_quantizer
+        if torch.is_autocast_enabled():
+            ctx.dtype = torch.get_autocast_dtype("cuda")
+        else:
+            ctx.dtype = self.weight.dtype
+        ctx.input_requires_grad = requires_grad
+        ctx.weight_requires_grad = requires_grad and self.weight.requires_grad

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py
index 88f563b2c5..11efc7eb3b 100644
--- a/transformer_engine/pytorch/ops/basic/bias.py
+++ b/transformer_engine/pytorch/ops/basic/bias.py
@@ -113,24 +113,35 @@ def pre_first_fuser_forward(self) -> None:
         if self.bias.device.type == "meta":
             self.reset_parameters()

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:
         x = input_
         b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
+        return x + b, ()

-        if ctx.requires_grad:
-            ctx.grad_input_quantizer = prev_op_grad_output_quantizer
-            if FP8GlobalStateManager.is_fp8_enabled():
-                fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
-                if fp8_recipe.backward_override is not None:
-                    ctx.grad_input_quantizer = None
-
-        return x + b
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        ctx.grad_input_quantizer = prev_op_grad_output_quantizer
+        if FP8GlobalStateManager.is_fp8_enabled():
+            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if fp8_recipe.backward_override is not None:
+                ctx.grad_input_quantizer = None

     def op_backward(
         self,
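One detail worth noting in the save-context hooks above (and in several files below): `op_forward_save_ctx` no longer sees the `dtype` local computed inside `op_forward_compute`, so it re-derives it from the autocast state. The recurring pattern, lifted directly from this diff — only the helper name is illustrative:

    import torch

    def _ctx_dtype(fallback: torch.dtype) -> torch.dtype:
        # Mirrors the save_ctx hooks in this patch: under autocast the
        # effective compute dtype is the autocast dtype; otherwise fall
        # back to the input's (or weight's) dtype.
        if torch.is_autocast_enabled():
            return torch.get_autocast_dtype("cuda")
        return fallback

`torch.get_autocast_dtype` takes the device type string and is available in recent PyTorch releases.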
diff --git a/transformer_engine/pytorch/ops/basic/constant_scale.py b/transformer_engine/pytorch/ops/basic/constant_scale.py
index d4b3660acf..99a49341e3 100644
--- a/transformer_engine/pytorch/ops/basic/constant_scale.py
+++ b/transformer_engine/pytorch/ops/basic/constant_scale.py
@@ -23,14 +23,27 @@ def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = scale

-    def op_forward(
+    def op_forward_compute(
+        self,
+        input_: torch.Tensor,
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:
+        return input_ * self.scale, ()
+
+    def op_forward_save_ctx(
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
-        return input_ * self.scale
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        pass

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py
index 8850604aad..2cb4de0c07 100644
--- a/transformer_engine/pytorch/ops/basic/dropout.py
+++ b/transformer_engine/pytorch/ops/basic/dropout.py
@@ -5,10 +5,11 @@
 """Fusible operation for dropout."""

 from __future__ import annotations
-from typing import Optional
+from typing import Optional, Union

 import torch

 import transformer_engine_torch as tex
+from ...quantized_tensor import QuantizedTensorStorage
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor import Quantizer
 from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
@@ -29,13 +30,14 @@ def __init__(self, p: float) -> None:
         super().__init__()
         self.dropout_probability: float = p

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

         # Output dtype
         dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
@@ -69,16 +71,34 @@ def op_forward(
         else:
             raise ValueError(f"Unsupported forward implementation {impl}")

-        # Save context for backward
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(mask)
-            ctx.save_for_backward(mask)
-            ctx.impl = impl
-            ctx.dropout_probability = self.dropout_probability
-            ctx.dtype = dtype
+        return out, (mask,)
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        (mask,) = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(mask)
+        ctx.save_for_backward(mask)

-        return out
+        dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
+        if not self.training:
+            ctx.impl = "evaluation"
+        elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16):
+            ctx.impl = "fused"
+        else:
+            ctx.impl = "unfused"
+        ctx.dropout_probability = self.dropout_probability
+        ctx.dtype = dtype

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py
index e21625276c..15e3033df0 100644
--- a/transformer_engine/pytorch/ops/basic/grouped_linear.py
+++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -9,11 +9,12 @@
 import contextlib
 import functools
 import math
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch

 import transformer_engine_torch as tex
+from ...quantized_tensor import QuantizedTensorStorage
 from ...cpp_extensions import general_grouped_gemm
 from ...distributed import CudaRNGStatesTracker
 from ...module._common import WeightGradStore
@@ -657,14 +658,6 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
                 recipe.fp8_quant_bwd_grad.amax_epsilon
             )

-    def op_forward(self, *args, **kwargs):
-        raise RuntimeError(
-            f"{self.__class__.__name__} operation has "
-            f"{self.num_extra_inputs} extra tensor inputs "
-            f"and {self.num_extra_outputs} extra tensor outputs. "
-            "It overrides `fuser_forward` instead of `op_forward`."
-        )
-
     def op_backward(self, *args, **kwargs):
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "
@@ -673,16 +666,20 @@ def op_backward(self, *args, **kwargs):
             "It overrides `fuser_backward` instead of `op_backward`."
         )

-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:
         num_groups = self.num_groups
         has_bias = self.has_bias
         weight_param = self.weight if self.single_grouped_weight else self.weight0
@@ -695,20 +692,17 @@ def fuser_forward(
                 raise RuntimeError("MAIN GRAD IS NONE")

         # Check which grads are required
-        ctx = basic_op_ctxs[0]
-        input_requires_grad = ctx.requires_grad
-        weight_requires_grad = ctx.requires_grad and weight_param.requires_grad
+        input_requires_grad = requires_grad[0]
+        weight_requires_grad = requires_grad[0] and weight_param.requires_grad

         # Quantizers
         input_quantizers = [None] * num_groups
         weight_quantizers = [None] * num_groups
-        grad_output_quantizers = [None] * num_groups
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             for group_idx in range(num_groups):
                 input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx)
                 weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1)
-                grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx)

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
@@ -812,24 +806,60 @@ def fuser_forward(
             for x in xs:
                 x.update_usage(rowwise_usage=False, columnwise_usage=True)

-        # Save state for backward pass
-        if ctx.requires_grad:
+        # Build tensors to save for backward pass
+        if requires_grad[0]:
             saved = [split_sizes]
             if self._scale_bias:
                 saved.append(scales)
             saved.extend(xs)
             saved.extend(ws)
-            ctx.save_for_backward(*saved)
-            ctx.with_quantized_compute = with_quantized_compute
-            ctx.input_quantizers = input_quantizers
-            ctx.weight_quantizers = weight_quantizers
-            ctx.grad_output_quantizers = grad_output_quantizers
-            ctx.grad_input_quantizers = None
-            ctx.dtype = dtype
-            ctx.input_requires_grad = input_requires_grad
-            ctx.weight_requires_grad = weight_requires_grad
-
-        return out, [()]
+            tensors_to_save = [tuple(saved)]
+        else:
+            tensors_to_save = [()]
+
+        return out, [()], tensors_to_save
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        if not requires_grad[0]:
+            return
+        ctx = basic_op_ctxs[0]
+        ctx.save_for_backward(*tensors_to_save[0])
+
+        num_groups = self.num_groups
+        weight_param = self.weight if self.single_grouped_weight else self.weight0
+
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        input_quantizers = [None] * num_groups
+        weight_quantizers = [None] * num_groups
+        grad_output_quantizers = [None] * num_groups
+        if with_quantized_compute:
+            for group_idx in range(num_groups):
+                input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx)
+                weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1)
+                grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx)
+
+        ctx.with_quantized_compute = with_quantized_compute
+        ctx.input_quantizers = input_quantizers
+        ctx.weight_quantizers = weight_quantizers
+        ctx.grad_output_quantizers = grad_output_quantizers
+        ctx.grad_input_quantizers = None
+        if torch.is_autocast_enabled():
+            ctx.dtype = torch.get_autocast_dtype("cuda")
+        else:
+            ctx.dtype = weight_param.dtype
+        ctx.input_requires_grad = requires_grad[0]
+        ctx.weight_requires_grad = requires_grad[0] and weight_param.requires_grad

     def fuser_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py
index 9e90bd98c0..4d346eb766 100644
--- a/transformer_engine/pytorch/ops/basic/identity.py
+++ b/transformer_engine/pytorch/ops/basic/identity.py
@@ -19,14 +19,27 @@ class Identity(BasicOperation):
     """Return input tensor"""

-    def op_forward(
+    def op_forward_compute(
+        self,
+        input_: torch.Tensor,
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:
+        return input_, ()
+
+    def op_forward_save_ctx(
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
-        return input_
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        pass

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py
index be155c9356..257efa7b69 100644
--- a/transformer_engine/pytorch/ops/basic/l2normalization.py
+++ b/transformer_engine/pytorch/ops/basic/l2normalization.py
@@ -5,11 +5,12 @@
 """Fusable operation for L2 Normalization."""

 from __future__ import annotations
-from typing import Optional
+from typing import Optional, Union
 import os

 import torch
+from ...quantized_tensor import QuantizedTensorStorage
 from ...torch_version import torch_version
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...jit import (
@@ -77,36 +78,38 @@ def __init__(
             for hidden_size in common_hidden_sizes:
                 warmup_jit_l2normalization_all_dtypes(hidden_size, seq_length, micro_batch_size)

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
-        # Use input directly - torch.compile can handle multi-dimensional tensors
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:
         x = maybe_dequantize(input_)

-        # Check if backward pass is needed
-        requires_grad = ctx.requires_grad
-
-        # Compute L2 normalization using fused implementation
-        # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
         if requires_grad:
-            # Training: use version that returns output and intermediate values for backward pass
             y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
-        else:
-            # Inference: use lightweight version that only returns output
-            y = l2normalization_fused(x, self.eps)
-            rsqrt_norm = None  # Not needed for inference
-
-        # Save state for backward pass
-        if requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(x, rsqrt_norm)
-            ctx.save_for_backward(x, rsqrt_norm)
+            return y, (x, rsqrt_norm)
+        y = l2normalization_fused(x, self.eps)
+        return y, (None, None)

-        return y
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        x, rsqrt_norm = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(x, rsqrt_norm)
+        ctx.save_for_backward(x, rsqrt_norm)

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py
index 3fda5145c6..b4688c8f69 100644
--- a/transformer_engine/pytorch/ops/basic/layer_norm.py
+++ b/transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -8,11 +8,12 @@
 from collections.abc import Iterable
 import math
 import os
-from typing import Optional
+from typing import Optional, Union

 import torch
 from transformer_engine_torch import layernorm_bwd, layernorm_fwd

+from ...quantized_tensor import QuantizedTensorStorage
 from ...constants import TE_DType
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...export import is_in_onnx_export_mode
@@ -173,15 +174,16 @@ def pre_first_fuser_forward(self) -> None:
         if self.weight.device.type == "meta" or self.bias.device.type == "meta":
             self.reset_parameters()

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:
         if is_in_onnx_export_mode():
-            return self.op_onnx_forward(input_)
+            return self.op_onnx_forward(input_), ()

         # Check tensor dims
         weight = self.weight
@@ -201,7 +203,7 @@ def op_forward(
         b = maybe_dequantize(self.bias, dtype).view((inner_dim,))

         # Compute layer norm
-        sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
+        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
         y, means, rstdevs = layernorm_fwd(
             x,
             w,
@@ -214,16 +216,30 @@ def op_forward(
             self.zero_centered_gamma,
         )

-        # Save state for backward pass
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(x, means, rstdevs)
-            ctx.save_for_backward(x, means, rstdevs)
-            ctx.dtype = dtype
-
         # Reshape output tensor
         out = y.view(input_dims)
-        return out
+
+        if requires_grad:
+            return out, (x, means, rstdevs)
+        return out, (None, None, None)
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        x, means, rstdevs = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(x, means, rstdevs)
+        ctx.save_for_backward(x, means, rstdevs)
+        ctx.dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py
index 0d9c870262..c99b3cc192 100644
--- a/transformer_engine/pytorch/ops/basic/make_extra_output.py
+++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py
@@ -47,33 +47,39 @@ def __init__(self, *, in_place: bool = False):
         super().__init__()
         self._in_place: bool = in_place

-    def op_forward(self, *args, **kwargs) -> None:
-        raise RuntimeError(
-            "{self.__class__.__name__} operation has "
-            f"{self.num_extra_inputs} extra tensor inputs "
-            f"and {self.num_extra_outputs} extra tensor outputs. "
-            "It overrides `fuser_forward` instead of `op_forward`."
-        )
-
     def op_backward(self, *args, **kwargs) -> None:
         raise RuntimeError(
-            "{self.__class__.__name__} operation has "
+            f"{self.__class__.__name__} operation has "
             f"{self.num_extra_inputs} extra tensor inputs "
             f"and {self.num_extra_outputs} extra tensor outputs. "
             "It overrides `fuser_backward` instead of `op_backward`."
         )

-    def fuser_forward(
+    def fuser_forward_compute(
+        self,
+        input_: torch.Tensor,
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]], list[tuple[()]]]:
+        return input_, [(input_,)], [()]
+
+    def fuser_forward_save_ctx(
         self,
         basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
+        tensors_to_save: list[tuple[()]],
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
-        return input_, [(input_,)]
+    ) -> None:
+        pass

     def fuser_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py
index d0c1137d91..b9e0a8712a 100644
--- a/transformer_engine/pytorch/ops/basic/quantize.py
+++ b/transformer_engine/pytorch/ops/basic/quantize.py
@@ -46,32 +46,44 @@ def num_quantizers(self, mode: str) -> int:
             return 1
         return 0

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:

         # Check if FP8 is enabled
         fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
         quantize_forward = fp8_enabled and self._quantize_forward
-        quantize_backward = fp8_enabled and self._quantize_backward
-
-        # Backward quantization is controlled by recipe backward override.
-        if fp8_enabled:
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-            quantize_backward = quantize_backward and recipe.backward_override is None

         # Quantize if needed
         out = input_
         if quantize_forward and not is_quantized_tensor(out):
             out = self.get_quantizer("forward", 0)(out)

-        if ctx.requires_grad:
-            ctx.quantize_backward = quantize_backward
-
-        return out
+        return out, ()
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        quantize_backward = fp8_enabled and self._quantize_backward
+        if fp8_enabled:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            quantize_backward = quantize_backward and recipe.backward_override is None
+        ctx.quantize_backward = quantize_backward

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py
index 0169da2490..a8a6aa9d92 100644
--- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py
+++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py
@@ -36,17 +36,18 @@ def __init__(
         self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
         self.process_group_size: int = torch.distributed.get_world_size(process_group)

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:

         # Trivial case
         if self.process_group_size == 1:
-            return input_.detach()
+            return input_.detach(), ()

         # Tensor dimensions
         input_dims = input_.size()
@@ -65,7 +66,19 @@ def op_forward(
         # Perform reduce-scatter
         y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
         torch.distributed.reduce_scatter_tensor(y, x, group=self.process_group)
-        return y
+        return y, ()
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        pass

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py
index 4a171c294b..703af004c3 100644
--- a/transformer_engine/pytorch/ops/basic/reshape.py
+++ b/transformer_engine/pytorch/ops/basic/reshape.py
@@ -34,16 +34,28 @@ def __init__(self, shape: Iterable[int]) -> None:
         super().__init__()
         self._shape = tuple(shape)

-    def op_forward(
+    def op_forward_compute(
+        self,
+        input_: torch.Tensor,
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[()]]:
+        return input_.reshape(*self._shape), ()
+
+    def op_forward_save_ctx(
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
-        if ctx.requires_grad:
+        tensors_to_save: tuple[()],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if requires_grad:
             ctx.input_shape = input_.size()
-        return input_.reshape(*self._shape)

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py
index 1d8d8be971..13997b0c51 100644
--- a/transformer_engine/pytorch/ops/basic/rmsnorm.py
+++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -8,11 +8,12 @@
 from collections.abc import Iterable
 import math
 import os
-from typing import Optional
+from typing import Optional, Union

 import torch
 from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd

+from ...quantized_tensor import QuantizedTensorStorage
 from ...constants import TE_DType
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...export import is_in_onnx_export_mode
@@ -156,15 +157,16 @@ def pre_first_fuser_forward(self) -> None:
         if self.weight.device.type == "meta":
             self.reset_parameters()

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:
         if is_in_onnx_export_mode():
-            return self.op_onnx_forward(input_)
+            return self.op_onnx_forward(input_), ()

         # Check tensor dims
         weight = self.weight
@@ -183,7 +185,7 @@ def op_forward(
         w = maybe_dequantize(self.weight, dtype).view((inner_dim,))

         # Compute RMSNorm
-        sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
+        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
         y, _, rstdevs = rmsnorm_fwd(
             x,
             w,
@@ -195,16 +197,30 @@ def op_forward(
             self.zero_centered_gamma,
         )

-        # Save state for backward pass
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(x, rstdevs)
-            ctx.save_for_backward(x, rstdevs)
-            ctx.dtype = dtype
-
         # Reshape output tensor
         out = y.view(input_dims)
-        return out
+
+        if requires_grad:
+            return out, (x, rstdevs)
+        return out, (None, None)
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        x, rstdevs = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(x, rstdevs)
+        ctx.save_for_backward(x, rstdevs)
+        ctx.dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)

     def op_backward(
         self,

diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py
index 9c0bc86bc1..6c74c93930 100644
--- a/transformer_engine/pytorch/ops/basic/swiglu.py
+++ b/transformer_engine/pytorch/ops/basic/swiglu.py
@@ -6,11 +6,12 @@
 from __future__ import annotations
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch

 import transformer_engine_torch as tex
+from ...quantized_tensor import QuantizedTensorStorage
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor import Float8CurrentScalingQuantizer, Quantizer
 from ...utils import clear_tensor_data
@@ -79,13 +80,14 @@ def __init__(
         self.cache_quantized_input: bool = cache_quantized_input
         self.glu_interleave_size: Optional[int] = glu_interleave_size

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

         # Compute dtype
         dtype: torch.dtype
@@ -97,10 +99,10 @@ def op_forward(
             raise RuntimeError(f"Unsupported dtype ({dtype})")

         # Check input tensor
-        input_ = maybe_dequantize(input_.contiguous(), dtype)
+        x = maybe_dequantize(input_.contiguous(), dtype)

         # Remove interleaving if needed
-        swiglu_in = input_
+        swiglu_in = x
         if self.glu_interleave_size is not None:
             shape = swiglu_in.size()
             swiglu_in = swiglu_in.reshape(
@@ -119,20 +121,36 @@ def op_forward(
         if self.cache_quantized_input:
             input_quantizer = Float8CurrentScalingQuantizer(
                 tex.DType.kFloat8E4M3,
-                input_.device,
+                x.device,
             )
             input_quantizer.set_usage(rowwise=True, columnwise=False)
-            input_ = input_quantizer(input_)
+            x = input_quantizer(x)

-        # Save state for backward pass
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(input_)
-            ctx.save_for_backward(input_)
-            ctx.dtype = dtype
-            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
+        if requires_grad:
+            return out, (x,)
+        return out, (None,)

-        return out
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        (x,) = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(x)
+        ctx.save_for_backward(x)
+        if torch.is_autocast_enabled():
+            ctx.dtype = torch.get_autocast_dtype("cuda")
+        else:
+            ctx.dtype = input_.dtype
+        ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

     def op_backward(
         self,
@@ -259,13 +277,14 @@ def _tex_clamped_dswiglu(
             self.alpha,
         )

-    def op_forward(
+    def op_forward_compute(
         self,
-        ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
-    ) -> torch.Tensor:
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

         # Compute dtype
         dtype: torch.dtype
@@ -301,15 +320,31 @@ def op_forward(
             input_quantizer.set_usage(rowwise=True, columnwise=False)
             x = input_quantizer(x)

-        # Save state for backward pass
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(x)
-            ctx.save_for_backward(x)
-            ctx.dtype = dtype
-            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
+        if requires_grad:
+            return out, (x,)
+        return out, (None,)

-        return out
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+    ) -> None:
+        if not requires_grad:
+            return
+        (x,) = tensors_to_save
+        if is_cpu_offload_enabled():
+            mark_activation_offload(x)
+        ctx.save_for_backward(x)
+        if torch.is_autocast_enabled():
+            ctx.dtype = torch.get_autocast_dtype("cuda")
+        else:
+            ctx.dtype = input_.dtype
+        ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

     def op_backward(
         self,
@@ -383,14 +418,6 @@ def _glu_backward(
     ) -> torch.Tensor:
         raise NotImplementedError

-    def op_forward(self, *args, **kwargs) -> None:
-        raise RuntimeError(
-            f"{self.__class__.__name__} operation has "
-            f"{self.num_extra_inputs} extra tensor inputs "
-            f"and {self.num_extra_outputs} extra tensor outputs. "
-            "It overrides `fuser_forward` instead of `op_forward`."
-        )
-
     def op_backward(self, *args, **kwargs) -> None:
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "
@@ -399,16 +426,20 @@ def op_backward(self, *args, **kwargs) -> None:
             "It overrides `fuser_backward` instead of `op_backward`."
         )

-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:
         extra_input = basic_op_extra_inputs[0][0]

         # Determine compute dtype
@@ -439,20 +470,40 @@ def fuser_forward(
         swiglu_out = self._glu_forward(swiglu_in)
         out = swiglu_out * scales.unsqueeze(-1)

-        # Save state for backward pass
-        ctx = basic_op_ctxs[0]
-        if ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(input_)
-            ctx.input_requires_grad = True
-            ctx.extra_input_requires_grad = extra_input.requires_grad
-            ctx.dtype = dtype
-            ctx.save_for_backward(
-                input_,
-                scales if ctx.input_requires_grad else None,
-            )
+        if requires_grad[0]:
+            tensors_to_save = [(input_, scales)]
+        else:
+            tensors_to_save = [()]
+
+        return out, [()], tensors_to_save

-        return out, [()]
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        if not requires_grad[0]:
+            return
+        ctx = basic_op_ctxs[0]
+        saved_input = tensors_to_save[0][0]
+        if is_cpu_offload_enabled():
+            mark_activation_offload(saved_input)
+        ctx.save_for_backward(*tensors_to_save[0])
+        ctx.input_requires_grad = True
+        ctx.extra_input_requires_grad = basic_op_extra_inputs[0][0].requires_grad
+        if torch.is_autocast_enabled():
+            ctx.dtype = torch.get_autocast_dtype("cuda")
+        elif isinstance(input_, torch.Tensor):
+            ctx.dtype = input_.dtype
+        else:
+            ctx.dtype = basic_op_extra_inputs[0][0].dtype

     def fuser_backward(
         self,
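The fused operations below follow the same contract at fuser granularity: `fuser_forward_compute` returns the output, the per-op extra outputs, and a list with one tensors-to-save tuple per basic op, which `fuser_forward_save_ctx` then files into the matching `OperationContext`. Since the refactor is internal to the operation fuser, user-level code should behave the same before and after the patch; a hypothetical smoke test — class names are taken from this diff, but the constructor signatures and `te.ops.Sequential` usage are assumptions:

    import torch
    import transformer_engine.pytorch as te

    # te.ops.Sequential is expected to fuse adjacent ops such as
    # BasicLinear + Bias through the operation fuser.
    model = te.ops.Sequential(
        te.ops.BasicLinear(1024, 1024),
        te.ops.Bias(1024),
    )
    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    y = model(x)  # runs fuser_forward_compute, then fuser_forward_save_ctx
    y.sum().backward()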
diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
index 4e756ea531..2b222ebbbc 100644
--- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
@@ -9,11 +9,12 @@
 import functools
 import inspect
 import os
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch

 import transformer_engine_torch as tex
+from ...quantized_tensor import QuantizedTensorStorage
 from ...quantization import Recipe
 from ...tensor import Quantizer
 from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor
@@ -135,19 +136,22 @@ def __init__(
         # The act_func string should be fixed on the cuDNN FE side.
         self._cudnn_act_func: str = "geglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "swiglu"

-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:

         # Get basic operations
         fc1_op, _, fc2_op = self.basic_ops
-        fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs

         # Tensor properties
         fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features)
@@ -166,19 +170,17 @@ def fuser_forward(
             dtype = fc1_weight_param.dtype

         # Check which grads are required
-        requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
-        input_requires_grad = requires_grad
-        weight_requires_grad = requires_grad and (
+        any_requires_grad = any(requires_grad)
+        input_requires_grad = any_requires_grad
+        weight_requires_grad = any_requires_grad and (
             fc1_weight_param.requires_grad or fc2_weight_param.requires_grad
         )

         # Quantizers
         fc1_input_quantizer = fc1_op.get_quantizer("forward", 0)
         fc1_weight_quantizer = fc1_op.get_quantizer("forward", 1)
-        fc1_grad_output_quantizer = fc1_op.get_quantizer("backward", 0)
         fc2_input_quantizer = fc2_op.get_quantizer("forward", 0)
         fc2_weight_quantizer = fc2_op.get_quantizer("forward", 1)
-        fc2_grad_output_quantizer = fc2_op.get_quantizer("backward", 0)

         # Extract split sizes from extra input
         fc1_split_sizes = basic_op_extra_inputs[0][0]
@@ -491,38 +493,27 @@ def fuser_forward(
             fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs)
             fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous()

-        # Save state for backward pass
-        if requires_grad:
+        # Prepare tensors for backward pass
+        if any_requires_grad:
             mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x)
             fc1_input_tensors = (
                 grouped_fc1_x.columnwise_data,
                 grouped_fc1_x.columnwise_scale_inv,
                 fc1_x_tensor_offsets,
             )
-            # FC1
             fc1_weight_tensors = (
                 [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight
             )
-            fc1_ctx.save_for_backward(
-                split_sizes, split_points, *fc1_weight_tensors, *fc1_input_tensors
-            )
-            fc1_ctx.with_quantized_compute = True
-            fc1_ctx.input_quantizer = fc1_input_quantizer
-            fc1_ctx.weight_quantizer = fc1_weight_quantizer
-            fc1_ctx.grad_output_quantizer = fc1_grad_output_quantizer
-            fc1_ctx.grad_input_quantizers = None
-            fc1_ctx.dtype = dtype
-            fc1_ctx.input_requires_grad = input_requires_grad
-            fc1_ctx.weight_requires_grad = weight_requires_grad
-            fc1_ctx.base_split_offsets = base_offsets
-
-            # Scaled SwiGLU
-            swiglu_ctx.save_for_backward(swiglu_in, scales)
-            swiglu_ctx.input_requires_grad = True
-            swiglu_ctx.extra_input_requires_grad = True
-            swiglu_ctx.dtype = dtype
-
-            # FC2 state
+            fc1_saved = [
+                split_sizes,
+                split_points,
+                *fc1_weight_tensors,
+                *fc1_input_tensors,
+                base_offsets,
+            ]
+
+            swiglu_saved = (swiglu_in, scales)
+
             if grouped_fc2_x is not None:
                 fc2_input_tensors = (
                     grouped_fc2_x.columnwise_data,
@@ -531,22 +522,76 @@ def fuser_forward(
                     grouped_fc2_x.columnwise_scale_inv,
                 )
             else:
                 fc2_input_tensors = (None, None, None)
-
+            fc2_saved = [split_sizes]
             if fc2_op.single_grouped_weight:
-                fc2_ctx.save_for_backward(split_sizes, grouped_fc2_weight, *fc2_input_tensors)
+                fc2_saved.append(grouped_fc2_weight)
             else:
-                fc2_ctx.save_for_backward(split_sizes, *grouped_fc2_weight, *fc2_input_tensors)
-
-            fc2_ctx.with_quantized_compute = True
-            fc2_ctx.input_quantizer = fc2_input_quantizer
-            fc2_ctx.weight_quantizer = fc2_weight_quantizer
-            fc2_ctx.grad_output_quantizer = fc2_grad_output_quantizer
-            fc2_ctx.grad_input_quantizers = None
-            fc2_ctx.dtype = dtype
-            fc2_ctx.input_requires_grad = input_requires_grad
-            fc2_ctx.weight_requires_grad = weight_requires_grad
-
-        return fc2_out, [(), (), ()]
+                fc2_saved.extend(grouped_fc2_weight)
+            fc2_saved.extend(fc2_input_tensors)
+
+            tensors_to_save = [tuple(fc1_saved), swiglu_saved, tuple(fc2_saved)]
+        else:
+            tensors_to_save = [(), (), ()]
+
+        return fc2_out, [(), (), ()], tensors_to_save
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        if not any(requires_grad):
+            return
+
+        fc1_op, _, fc2_op = self.basic_ops
+        fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs
+
+        fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0
+        fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0
+        input_requires_grad = True
+        weight_requires_grad = fc1_weight_param.requires_grad or fc2_weight_param.requires_grad
+
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_dtype("cuda")
+        else:
+            dtype = fc1_weight_param.dtype
+
+        # FC1 context
+        *fc1_bwd_tensors, base_offsets = tensors_to_save[0]
+        fc1_ctx.save_for_backward(*fc1_bwd_tensors)
+        fc1_ctx.with_quantized_compute = True
+        fc1_ctx.input_quantizer = fc1_op.get_quantizer("forward", 0)
+        fc1_ctx.weight_quantizer = fc1_op.get_quantizer("forward", 1)
+        fc1_ctx.grad_output_quantizer = fc1_op.get_quantizer("backward", 0)
+        fc1_ctx.grad_input_quantizers = None
+        fc1_ctx.dtype = dtype
+        fc1_ctx.input_requires_grad = input_requires_grad
+        fc1_ctx.weight_requires_grad = weight_requires_grad
+        fc1_ctx.base_split_offsets = base_offsets
+
+        # Scaled SwiGLU context
+        swiglu_ctx.save_for_backward(*tensors_to_save[1])
+        swiglu_ctx.input_requires_grad = True
+        swiglu_ctx.extra_input_requires_grad = True
+        swiglu_ctx.dtype = dtype
+
+        # FC2 context
+        fc2_ctx.save_for_backward(*tensors_to_save[2])
+        fc2_ctx.with_quantized_compute = True
+        fc2_ctx.input_quantizer = fc2_op.get_quantizer("forward", 0)
+        fc2_ctx.weight_quantizer = fc2_op.get_quantizer("forward", 1)
+        fc2_ctx.grad_output_quantizer = fc2_op.get_quantizer("backward", 0)
+        fc2_ctx.grad_input_quantizers = None
+        fc2_ctx.dtype = dtype
+        fc2_ctx.input_requires_grad = input_requires_grad
+        fc2_ctx.weight_requires_grad = weight_requires_grad

     def fuse_forward_ops(

diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
index 8df929f799..6469c2419a 100644
--- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
+++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
@@ -6,11 +6,11 @@
 from __future__ import annotations
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch

-from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...quantized_tensor import QuantizedTensorStorage
 from ...quantization import FP8GlobalStateManager
 from ...tensor import Quantizer
 from ..basic import BasicLinear, Bias
@@ -50,47 +50,41 @@ def __init__(
         # Index of each basic operations
         self._op_idxs: dict[str, Optional[int]] = op_idxs

-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:

         # Get basic operations
-        idx = self._op_idxs["linear"]
-        linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[idx]
-        if self._op_idxs["bias"] is None:
-            bias_op = None
-            bias_op_ctx = None
-            bias = None
-        else:
-            idx = self._op_idxs["bias"]
-            bias_op = self.basic_ops[idx]
-            bias_op_ctx = basic_op_ctxs[idx]
-            bias = bias_op.bias
-            if basic_op_kwargs[idx]:
+        linear_idx = self._op_idxs["linear"]
+        linear_op = self.basic_ops[linear_idx]
+        bias = None
+        if self._op_idxs["bias"] is not None:
+            bias_idx = self._op_idxs["bias"]
+            bias = self.basic_ops[bias_idx].bias
+            if basic_op_kwargs[bias_idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
-        if self._op_idxs["activation"] is None:
-            activation_op = None  # pylint: disable=unused-variable
-        else:
+        if self._op_idxs["activation"] is not None:
             raise NotImplementedError("Activations are not yet supported")

         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad
-        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+        input_requires_grad = requires_grad[linear_idx]
+        weight_requires_grad = requires_grad[linear_idx] and linear_op.weight.requires_grad

         # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = next_op_input_quantizer
-        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
@@ -121,34 +115,54 @@ def fuser_forward(
             weight_requires_grad=weight_requires_grad,
         )

-        # Save state for backward pass
-        if linear_op_ctx.requires_grad:
+        # Determine tensors to save for backward pass
+        if requires_grad[linear_idx]:
             if backward_override == "high_precision":
                 saved_input = input_ if weight_requires_grad else None
input_requires_grad else None else: saved_input = x_local saved_weight = w - if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) - linear_op_ctx.save_for_backward(saved_input, saved_weight) - linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and backward_override is None + linear_tensors = (saved_input, saved_weight) + else: + linear_tensors = (None, None) + + tensors_to_save = [() for _ in range(len(self.basic_ops))] + tensors_to_save[linear_idx] = linear_tensors + + return output, [() for _ in range(len(self.basic_ops))], tensors_to_save + + def fuser_forward_save_ctx( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]], + *, + requires_grad: list[bool], + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> None: + linear_idx = self._op_idxs["linear"] + linear_op = self.basic_ops[linear_idx] + linear_op.op_forward_save_ctx( + basic_op_ctxs[linear_idx], + input_, + tensors_to_save[linear_idx], + requires_grad=requires_grad[linear_idx], + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, + ) + if self._op_idxs["bias"] is not None: + bias_idx = self._op_idxs["bias"] + bias_op = self.basic_ops[bias_idx] + bias_op.op_forward_save_ctx( + basic_op_ctxs[bias_idx], + input_, + tensors_to_save[bias_idx], + requires_grad=requires_grad[bias_idx], + prev_op_grad_output_quantizer=linear_op.get_grad_output_quantizer(), ) - linear_op_ctx.backward_override = backward_override - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() - if backward_override is not None: - bias_op_ctx.grad_input_quantizer = None - - return output, [() for _ in range(len(self.basic_ops))] @staticmethod def fuse_forward_ops( diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5376a7d264..abf1c119a8 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -6,11 +6,11 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any, Optional +from typing import Any, Optional, Union import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...quantized_tensor import QuantizedTensorStorage from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias @@ -48,55 +48,45 @@ def __init__( # Index of each basic operations self._op_idxs: dict[str, Optional[int]] = op_idxs - def fuser_forward( + def fuser_forward_compute( self, - basic_op_ctxs: list[OperationContext], input_: torch.Tensor, *, + requires_grad: list[bool], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], 
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:
 
         # Get basic operations
-        idx = self._op_idxs["linear"]
-        linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[idx]
-        if self._op_idxs["bias"] is None:
-            bias_op = None
-            bias_op_ctx = None
-            bias = None
-        else:
-            idx = self._op_idxs["bias"]
-            bias_op = self.basic_ops[idx]
-            bias_op_ctx = basic_op_ctxs[idx]
-            bias = bias_op.bias
-            if basic_op_kwargs[idx]:
+        linear_idx = self._op_idxs["linear"]
+        linear_op = self.basic_ops[linear_idx]
+        bias = None
+        if self._op_idxs["bias"] is not None:
+            bias_idx = self._op_idxs["bias"]
+            bias = self.basic_ops[bias_idx].bias
+            if basic_op_kwargs[bias_idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
 
         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad
-        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+        input_requires_grad = requires_grad[linear_idx]
+        weight_requires_grad = requires_grad[linear_idx] and linear_op.weight.requires_grad
 
         # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
-        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
         else:
             backward_override = None
 
-        # Get autocast dtype if needed
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_dtype("cuda")
-        else:
-            dtype = linear_op.weight.dtype
-
         # Linear forward
         output = basic_op_extra_inputs[self._op_idxs["add"]][0]
         output, x_local, w = BasicLinear._functional_forward(
@@ -118,34 +108,54 @@ def fuser_forward(
             weight_requires_grad=weight_requires_grad,
         )
 
-        # Save state for backward pass
-        if linear_op_ctx.requires_grad:
+        # Determine tensors to save for backward pass
+        if requires_grad[linear_idx]:
             if backward_override == "high_precision":
                 saved_input = input_ if weight_requires_grad else None
                 saved_weight = linear_op.weight if input_requires_grad else None
             else:
                 saved_input = x_local
                 saved_weight = w
-            if is_cpu_offload_enabled():
-                mark_activation_offload(saved_input)
-            linear_op_ctx.save_for_backward(saved_input, saved_weight)
-            linear_op_ctx.with_quantized_compute = (
-                with_quantized_compute and backward_override is None
-            )
-            linear_op_ctx.backward_override = backward_override
-            linear_op_ctx.input_quantizer = input_quantizer
-            linear_op_ctx.weight_quantizer = weight_quantizer
-            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-            linear_op_ctx.dtype = dtype
-            linear_op_ctx.input_requires_grad = input_requires_grad
-            linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None and bias_op_ctx.requires_grad:
-            bias_op_ctx.grad_input_quantizer = (
-                None if backward_override is not None else linear_op.get_grad_output_quantizer()
-            )
+            linear_tensors = (saved_input, saved_weight)
+        else:
+            linear_tensors = (None, None)
+
+        tensors_to_save = [() for _ in range(len(self.basic_ops))]
+        tensors_to_save[linear_idx] = linear_tensors
 
-        return output, [() for _ in range(len(self.basic_ops))]
+        return output, [() for _ in range(len(self.basic_ops))], tensors_to_save
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        linear_idx = self._op_idxs["linear"]
+        linear_op = self.basic_ops[linear_idx]
+        linear_op.op_forward_save_ctx(
+            basic_op_ctxs[linear_idx],
+            input_,
+            tensors_to_save[linear_idx],
+            requires_grad=requires_grad[linear_idx],
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
+        )
+        if self._op_idxs["bias"] is not None:
+            bias_idx = self._op_idxs["bias"]
+            bias_op = self.basic_ops[bias_idx]
+            bias_op.op_forward_save_ctx(
+                basic_op_ctxs[bias_idx],
+                input_,
+                tensors_to_save[bias_idx],
+                requires_grad=requires_grad[bias_idx],
+                prev_op_grad_output_quantizer=linear_op.get_grad_output_quantizer(),
+            )
 
     @staticmethod
     def fuse_forward_ops(
diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
index abeb39adfa..8194d6a7a0 100644
--- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
+++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
@@ -6,11 +6,11 @@
 from __future__ import annotations
 
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
-from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...quantized_tensor import QuantizedTensorStorage
 from ...quantization import FP8GlobalStateManager
 from ...tensor import Quantizer
 from ..basic import AddExtraInput, BasicLinear, ConstantScale
@@ -38,32 +38,33 @@ def __init__(
     ) -> None:
         super().__init__((linear, scale, add))
 
-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:
 
         # Get basic operations
         linear_op = self.basic_ops[0]
-        linear_op_ctx = basic_op_ctxs[0]
         scale_op = self.basic_ops[1]
 
         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad
-        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+        input_requires_grad = requires_grad[0]
+        weight_requires_grad = requires_grad[0] and linear_op.weight.requires_grad
 
         # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
-        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
@@ -99,30 +100,40 @@ def fuser_forward(
             weight_requires_grad=weight_requires_grad,
         )
 
-        # Save state for backward pass
-        if linear_op_ctx.requires_grad:
+        # Determine tensors to save for backward pass
+        if requires_grad[0]:
             if backward_override == "high_precision":
                 saved_input = input_ if weight_requires_grad else None
                 saved_weight = linear_op.weight if input_requires_grad else None
             else:
                 saved_input = x_local
                 saved_weight = w
-            if is_cpu_offload_enabled():
-                mark_activation_offload(saved_input)
-            linear_op_ctx.save_for_backward(saved_input, saved_weight)
-            linear_op_ctx.with_quantized_compute = (
-                with_quantized_compute and backward_override is None
-            )
-            linear_op_ctx.backward_override = backward_override
-            linear_op_ctx.input_quantizer = input_quantizer
-            linear_op_ctx.weight_quantizer = weight_quantizer
-            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-            linear_op_ctx.dtype = dtype
-            linear_op_ctx.input_requires_grad = input_requires_grad
-            linear_op_ctx.weight_requires_grad = weight_requires_grad
-
-        return output, [() for _ in range(len(self.basic_ops))]
+            linear_tensors = (saved_input, saved_weight)
+        else:
+            linear_tensors = (None, None)
+
+        return output, [() for _ in range(len(self.basic_ops))], [linear_tensors, (), ()]
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        linear_op = self.basic_ops[0]
+        linear_op.op_forward_save_ctx(
+            basic_op_ctxs[0],
+            input_,
+            tensors_to_save[0],
+            requires_grad=requires_grad[0],
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
+        )
 
     @staticmethod
     def fuse_forward_ops(
diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
index 84073be6f8..2557bbfe0f 100644
--- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
@@ -6,13 +6,13 @@
 from __future__ import annotations
 
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
 from transformer_engine_torch import CommOverlapType
 
+from ...quantized_tensor import QuantizedTensorStorage
 from ...cpp_extensions import general_gemm
-from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import get_distributed_world_size
 from ...quantization import FP8GlobalStateManager
 from ...module.base import (
@@ -275,41 +275,36 @@ def _functional_forward(
         extra_outputs = {"input": x_local, "weight": w}
         return y_local, extra_outputs
 
-    def fuser_forward(
+    def fuser_forward_compute(
         self,
-        basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
+        requires_grad: list[bool],
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
-    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
+    ) -> tuple[
+        torch.Tensor,
+        Iterable[Iterable[torch.Tensor]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:
 
         # Get basic operations
-        idx = self._op_idxs["linear"]
-        linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[idx]
-        bias_op = None
-        bias_op_ctx = None
+        linear_idx = self._op_idxs["linear"]
+        linear_op = self.basic_ops[linear_idx]
         bias = None
         if self._op_idxs["bias"] is not None:
-            idx = self._op_idxs["bias"]
-            bias_op = self.basic_ops[idx]
-            bias_op_ctx = basic_op_ctxs[idx]
-            bias = bias_op.bias
-            if basic_op_kwargs[idx]:
+            bias_idx = self._op_idxs["bias"]
+            bias = self.basic_ops[bias_idx].bias
+            if basic_op_kwargs[bias_idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
 
         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad
-        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+        input_requires_grad = requires_grad[linear_idx]
+        weight_requires_grad = requires_grad[linear_idx] and linear_op.weight.requires_grad
 
         # Quantization metadata
-        input_quantizer = linear_op.get_quantizer("forward", 0)
-        weight_quantizer = linear_op.get_quantizer("forward", 1)
-        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
@@ -318,6 +313,9 @@ def fuser_forward(
                 f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
             )
 
+        input_quantizer = linear_op.get_quantizer("forward", 0)
+        weight_quantizer = linear_op.get_quantizer("forward", 1)
+
         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
@@ -350,24 +348,44 @@ def fuser_forward(
         x_local = extra_outputs["input"]
         w = extra_outputs["weight"]
 
-        # Save state for backward pass
-        if linear_op_ctx.requires_grad:
-            if is_cpu_offload_enabled():
-                mark_activation_offload(x_local)
-            linear_op_ctx.save_for_backward(x_local, w)
-            linear_op_ctx.with_quantized_compute = with_quantized_compute
-            linear_op_ctx.input_quantizer = input_quantizer
-            linear_op_ctx.weight_quantizer = weight_quantizer
-            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-            linear_op_ctx.dtype = dtype
-            linear_op_ctx.input_dims = input_.size()
-            linear_op_ctx.input_requires_grad = input_requires_grad
-            linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None and bias_op_ctx.requires_grad:
-            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
-
-        return output, [() for _ in range(len(self.basic_ops))]
+        tensors_to_save = [() for _ in range(len(self.basic_ops))]
+        tensors_to_save[linear_idx] = (x_local, w)
+
+        return output, [() for _ in range(len(self.basic_ops))], tensors_to_save
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        linear_idx = self._op_idxs["linear"]
+        linear_op = self.basic_ops[linear_idx]
+        linear_op.op_forward_save_ctx(
+            basic_op_ctxs[linear_idx],
+            input_,
+            tensors_to_save[linear_idx],
+            requires_grad=requires_grad[linear_idx],
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
+        )
+        if requires_grad[linear_idx]:
+            basic_op_ctxs[linear_idx].input_dims = input_.size()
+        if self._op_idxs["bias"] is not None:
+            bias_idx = self._op_idxs["bias"]
+            bias_op = self.basic_ops[bias_idx]
+            bias_op.op_forward_save_ctx(
+                basic_op_ctxs[bias_idx],
+                input_,
+                tensors_to_save[bias_idx],
+                requires_grad=requires_grad[bias_idx],
+                prev_op_grad_output_quantizer=linear_op.get_grad_output_quantizer(),
+            )
 
     @staticmethod
     def fuse_forward_ops(
diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py
index a3c7e1bac7..7dba1b6a34 100644
--- a/transformer_engine/pytorch/ops/fuser.py
+++ b/transformer_engine/pytorch/ops/fuser.py
@@ -128,14 +128,34 @@ def forward(
             if next_op is not None:
                 next_op_input_quantizer = next_op.get_input_quantizer()
 
-            x, fused_op_extra_outputs = op.fuser_forward(
-                [basic_op_ctxs[idx] for idx in basic_op_idxs],
-                x,
-                basic_op_extra_inputs=extra_inputs,
-                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
-                next_op_input_quantizer=next_op_input_quantizer,
-                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
-            )
+            if op._use_split_forward:
+                op_ctxs = [basic_op_ctxs[idx] for idx in basic_op_idxs]
+                fwd_kwargs = {
+                    "requires_grad": [basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs],
+                    "basic_op_extra_inputs": extra_inputs,
+                    "prev_op_grad_output_quantizer": prev_op_grad_output_quantizer,
+                    "next_op_input_quantizer": next_op_input_quantizer,
+                    "basic_op_kwargs": [basic_op_kwargs[idx] for idx in basic_op_idxs],
+                }
+                x_input = x
+                x, fused_op_extra_outputs, tensors_to_save = op.fuser_forward_compute(
+                    x_input, **fwd_kwargs
+                )
+                op.fuser_forward_save_ctx(
+                    op_ctxs,
+                    x_input,
+                    tensors_to_save,
+                    **fwd_kwargs,
+                )
+            else:
+                x, fused_op_extra_outputs = op.fuser_forward(
+                    [basic_op_ctxs[idx] for idx in basic_op_idxs],
+                    x,
+                    basic_op_extra_inputs=extra_inputs,
+                    prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
+                    next_op_input_quantizer=next_op_input_quantizer,
+                    basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
+                )
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                 for y in ys:
                     y.requires_grad_(idx >= fuser.first_op_requiring_backward)
diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py
index c5c8ea3463..e195372ab3 100644
--- a/transformer_engine/pytorch/ops/op.py
+++ b/transformer_engine/pytorch/ops/op.py
@@ -9,7 +9,7 @@
 from collections.abc import Iterable
 import dataclasses
 import pickle
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
@@ -19,6 +19,7 @@
     RecipeState,
     autocast,
 )
+from ..quantized_tensor import QuantizedTensorStorage
 from ..tensor import Quantizer
 
 
@@ -55,7 +56,41 @@ def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
 
 
 class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
-    """Tensor operation supported by the operation fuser"""
+    """Tensor operation supported by the operation fuser
+
+    Subclasses can define the forward pass using one of two APIs:
+
+    - **Legacy API**: Override ``fuser_forward`` (single method that
+      performs both computation and context saving).
+    - **Split API**: Override both ``fuser_forward_compute`` and
+      ``fuser_forward_save_ctx`` (separates computation from context
+      saving, enabling ``torch.compile`` compatibility).
+
+    The split API is preferred for new operations. If
+    ``fuser_forward_compute`` is defined, the operation automatically
+    uses the split API. Both methods must be defined together.
+ + """ + + _use_split_forward: bool = False + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + has_fuser_compute = "fuser_forward_compute" in cls.__dict__ + has_fuser_save_ctx = "fuser_forward_save_ctx" in cls.__dict__ + + if has_fuser_compute and not has_fuser_save_ctx: + raise TypeError( + f"{cls.__name__} defines fuser_forward_compute without " + "fuser_forward_save_ctx. Both must be defined together." + ) + if has_fuser_save_ctx and not has_fuser_compute: + raise TypeError( + f"{cls.__name__} defines fuser_forward_save_ctx without " + "fuser_forward_compute. Both must be defined together." + ) + if has_fuser_compute: + cls._use_split_forward = True @property @abc.abstractmethod @@ -90,6 +125,9 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: """Forward pass + Subclasses should implement either this method or both + ``fuser_forward_compute`` and ``fuser_forward_save_ctx``. + This op is either a basic op or the fusion of basic ops, so several of this function's arguments are lists of arguments to forward functions of corresponding basic ops. @@ -124,6 +162,97 @@ def fuser_forward( f"Forward pass is not implemented for operation ({self.__class__.__name__})" ) + def fuser_forward_compute( + self, + input_: torch.Tensor, + *, + requires_grad: list[bool], + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[torch.Tensor]], + list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]], + ]: + """Forward computation without contexts + + Alternative to ``fuser_forward`` that separates computation + from context saving. Must be paired with + ``fuser_forward_save_ctx``. + + Parameters + ---------- + input_: torch.Tensor + Input tensor + requires_grad: list of bool + Whether backward pass is required, per basic op + basic_op_extra_inputs: list of tuple of torch.Tensor + Extra tensor inputs to basic operations + prev_op_grad_output_quantizer: Quantizer, optional + The grad_output_quantizer of the preceeding operation + next_op_input_quantizer: Quantizer, optional + The input_quantizer of the following operation + basic_op_kwargs: list of dict + Keyword arguments to forward functions of basic + operations. + + Returns + ------- + torch.Tensor + Output tensor + Iterable of iterable of torch.Tensor + Extra tensor outputs from basic operations + list of tuple of Optional[Union[torch.Tensor, QuantizedTensorStorage]] + Tensors to save for backward, per basic op + + """ + raise NotImplementedError( + f"fuser_forward_compute is not implemented for operation ({self.__class__.__name__})" + ) + + def fuser_forward_save_ctx( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]], + *, + requires_grad: list[bool], + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> None: + """Save forward state to contexts for backward pass + + Companion to ``fuser_forward_compute``. 
+
+        Parameters
+        ----------
+        basic_op_ctxs: list of OperationContext
+            Contexts for basic operations
+        input_: torch.Tensor
+            Input tensor (same as passed to ``fuser_forward_compute``)
+        tensors_to_save: list of tuple of Optional[Union[torch.Tensor, QuantizedTensorStorage]]
+            Tensors returned by ``fuser_forward_compute``, per basic op
+        requires_grad: list of bool
+            Whether backward pass is required, per basic op
+        basic_op_extra_inputs: list of tuple of torch.Tensor
+            Extra tensor inputs to basic operations
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceding operation
+        next_op_input_quantizer: Quantizer, optional
+            The input_quantizer of the following operation
+        basic_op_kwargs: list of dict
+            Keyword arguments to forward functions of basic
+            operations.
+
+        """
+        raise NotImplementedError(
+            f"fuser_forward_save_ctx is not implemented for operation ({self.__class__.__name__})"
+        )
+
     def fuser_backward(
         self,
         basic_op_ctxs: list[OperationContext],
@@ -175,6 +304,18 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     This class holds parameters and state, even if the actual
     forward and backward passes are performed by a fused operation.
 
+    Subclasses can define the forward pass using one of two APIs:
+
+    - **Legacy API**: Override ``op_forward`` (single method that
+      performs both computation and context saving).
+    - **Split API**: Override both ``op_forward_compute`` and
+      ``op_forward_save_ctx`` (separates computation from context
+      saving, enabling ``torch.compile`` compatibility).
+
+    The split API is preferred for new operations. Defining both
+    ``op_forward`` and ``op_forward_compute`` is an error. The
+    split API methods must be defined together.
+
     """
 
     # Number of extra tensor inputs
@@ -182,6 +323,33 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     # Number of extra tensor outputs
     num_extra_outputs: int = 0
 
+    def __init_subclass__(cls, **kwargs) -> None:
+        super().__init_subclass__(**kwargs)
+        has_op_forward = "op_forward" in cls.__dict__
+        has_compute = "op_forward_compute" in cls.__dict__
+        has_save_ctx = "op_forward_save_ctx" in cls.__dict__
+
+        if has_op_forward and has_compute:
+            raise TypeError(
+                f"{cls.__name__} defines both op_forward and op_forward_compute. "
+                "Implement either op_forward (legacy) or "
+                "op_forward_compute + op_forward_save_ctx (split API), not both."
+            )
+        if has_compute and not has_save_ctx:
+            raise TypeError(
+                f"{cls.__name__} defines op_forward_compute without op_forward_save_ctx. "
+                "Both must be defined together."
+            )
+        if has_save_ctx and not has_compute:
+            raise TypeError(
+                f"{cls.__name__} defines op_forward_save_ctx without op_forward_compute. "
+                "Both must be defined together."
+            )
+        if has_compute:
+            cls._use_split_forward = True
+        elif has_op_forward:
+            cls._use_split_forward = False
+
     def __init__(self) -> None:
         super().__init__()
 
@@ -410,18 +578,20 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
                 self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                 self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
 
-    @abc.abstractmethod
     def op_forward(
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
         *,
-        prev_op_grad_output_quantizer: Optional[Quantizer],
-        next_op_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
         **kwargs: Any,
     ) -> torch.Tensor:
         """Forward pass
 
+        Subclasses should implement either this method or both
+        ``op_forward_compute`` and ``op_forward_save_ctx``.
+
         Parameters
         ----------
         ctx: OperationContext
@@ -439,6 +609,85 @@ def op_forward(
             Output tensor
 
         """
+        raise NotImplementedError(
+            f"Forward pass is not implemented for operation ({self.__class__.__name__}). "
+            "Implement either op_forward or both op_forward_compute and op_forward_save_ctx."
+        )
+
+    def op_forward_compute(
+        self,
+        input_: torch.Tensor,
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:
+        """Forward computation without context
+
+        Alternative to ``op_forward`` that separates computation from
+        context saving. Must be paired with ``op_forward_save_ctx``.
+
+        Parameters
+        ----------
+        input_: torch.Tensor
+            Input tensor
+        requires_grad: bool
+            Whether backward pass is required
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceding operation
+        next_op_input_quantizer: Quantizer, optional
+            The input_quantizer of the following operation
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor
+        tuple of Optional[Union[torch.Tensor, QuantizedTensorStorage]]
+            Tensors to save for backward pass
+
+        """
+        raise NotImplementedError(
+            f"op_forward_compute is not implemented for operation ({self.__class__.__name__})"
+        )
+
+    def op_forward_save_ctx(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
+        *,
+        requires_grad: bool,
+        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
+        next_op_input_quantizer: Optional[Quantizer] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Save forward state to context for backward pass
+
+        Companion to ``op_forward_compute``. Receives the same
+        arguments as ``op_forward_compute`` plus ``ctx`` and
+        ``tensors_to_save``. Override to save additional non-tensor
+        state needed for the backward pass.
+
+        Parameters
+        ----------
+        ctx: OperationContext
+            Context to coordinate between forward and backward passes
+        input_: torch.Tensor
+            Input tensor (same as passed to ``op_forward_compute``)
+        tensors_to_save: tuple of Optional[Union[torch.Tensor, QuantizedTensorStorage]]
+            Tensors returned by ``op_forward_compute``
+        requires_grad: bool
+            Whether backward pass is required
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceding operation
+        next_op_input_quantizer: Quantizer, optional
+            The input_quantizer of the following operation
+
+        """
+        raise NotImplementedError(
+            f"op_forward_save_ctx is not implemented for operation ({self.__class__.__name__})"
+        )
 
     @abc.abstractmethod
     def op_backward(
         self,
@@ -464,6 +713,16 @@ def op_backward(
 
         """
 
+    def _check_no_extra_io(self, fuser_method: str, op_method: str) -> None:
+        """Raise if op has extra inputs/outputs and must override fuser-level method."""
+        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
+            raise RuntimeError(
+                f"{self.__class__.__name__} operation has "
+                f"{self.num_extra_inputs} extra tensor inputs "
+                f"and {self.num_extra_outputs} extra tensor outputs. "
+                f"It should override `{fuser_method}` instead of `{op_method}`."
+            )
+
     def fuser_forward(
         self,
         basic_op_ctxs: list[OperationContext],
@@ -474,13 +733,7 @@ def fuser_forward(
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, list[tuple[()]]]:
-        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
-            raise RuntimeError(
-                "{self.__class__.__name__} operation has "
-                f"{self.num_extra_inputs} extra tensor inputs "
-                f"and {self.num_extra_outputs} extra tensor outputs. "
-                "It should override `fuser_forward` instead of `op_forward`."
-            )
+        self._check_no_extra_io("fuser_forward", "op_forward")
         output = self.op_forward(
             basic_op_ctxs[0],
             input_,
@@ -490,6 +743,53 @@ def fuser_forward(
         )
         return output, [()]
 
+    def fuser_forward_compute(
+        self,
+        input_: torch.Tensor,
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> tuple[
+        torch.Tensor,
+        list[tuple[()]],
+        list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+    ]:
+        self._check_no_extra_io("fuser_forward_compute", "op_forward_compute")
+        output, tensors_to_save = self.op_forward_compute(
+            input_,
+            requires_grad=requires_grad[0],
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
+            next_op_input_quantizer=next_op_input_quantizer,
+            **basic_op_kwargs[0],
+        )
+        return output, [()], [tensors_to_save]
+
+    def fuser_forward_save_ctx(
+        self,
+        basic_op_ctxs: list[OperationContext],
+        input_: torch.Tensor,
+        tensors_to_save: list[tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]],
+        *,
+        requires_grad: list[bool],
+        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
+        basic_op_kwargs: list[dict[str, Any]],
+    ) -> None:
+        self._check_no_extra_io("fuser_forward_save_ctx", "op_forward_save_ctx")
+        self.op_forward_save_ctx(
+            basic_op_ctxs[0],
+            input_,
+            tensors_to_save[0],
+            requires_grad=requires_grad[0],
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
+            next_op_input_quantizer=next_op_input_quantizer,
+            **basic_op_kwargs[0],
+        )
+
     def fuser_backward(
         self,
         basic_op_ctxs: list[OperationContext],
@@ -501,13 +801,7 @@ def fuser_backward(
         list[Iterable[Optional[torch.Tensor]]],
         list[tuple[()]],
     ]:
-        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
-            raise RuntimeError(
-                "{self.__class__.__name__} operation has "
-                f"{self.num_extra_inputs} extra tensor inputs "
-                f"and {self.num_extra_outputs} extra tensor outputs. "
-                "It should override `fuser_backward` instead of `op_backward`."
-            )
+        self._check_no_extra_io("fuser_backward", "op_backward")
         grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
         return grad_input, [grad_params], [()]
 
@@ -683,6 +977,13 @@ class FusedOperation(FusibleOperation):
     corresponding basic ops. This class should hold no parameters
     or other state, but should access them from the basic ops.
 
+    Subclasses can define the forward pass using one of two APIs
+    (inherited from ``FusibleOperation``):
+
+    - **Legacy API**: Override ``fuser_forward``.
+    - **Split API**: Override both ``fuser_forward_compute`` and
+      ``fuser_forward_save_ctx``.
+
     Parameters
     ----------
     basic_ops : iterable of FusibleOperation
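
For orientation, the sketch below shows what a minimal basic operation looks like under the split API that this patch introduces. It is illustrative only and not part of the diff: the ``ConstMul`` op, its scalar attribute, and the exact import paths for ``OperationContext`` and ``Quantizer`` are assumptions, while the ``op_forward_compute``/``op_forward_save_ctx``/``op_backward`` signatures follow the definitions added to ``transformer_engine/pytorch/ops/op.py`` above.

from typing import Any, Optional

import torch

from transformer_engine.pytorch.ops import BasicOperation
# Assumed import paths: OperationContext is defined in
# transformer_engine/pytorch/ops/op.py, and op.py imports Quantizer
# from transformer_engine.pytorch.tensor.
from transformer_engine.pytorch.ops.op import OperationContext
from transformer_engine.pytorch.tensor import Quantizer


class ConstMul(BasicOperation):
    """Hypothetical op that multiplies its input by a fixed scalar."""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

    def op_forward_compute(
        self,
        input_: torch.Tensor,
        *,
        requires_grad: bool,
        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
        next_op_input_quantizer: Optional[Quantizer] = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, tuple[Optional[torch.Tensor], ...]]:
        # Pure computation with no context mutation, so this path stays
        # friendly to torch.compile tracing.
        out = input_ * self.scale
        # The gradient only depends on self.scale, so no tensors are saved.
        return out, ()

    def op_forward_save_ctx(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        tensors_to_save: tuple[Optional[torch.Tensor], ...],
        *,
        requires_grad: bool,
        prev_op_grad_output_quantizer: Optional[Quantizer] = None,
        next_op_input_quantizer: Optional[Quantizer] = None,
        **kwargs: Any,
    ) -> None:
        # All context access is confined to this companion method.
        if requires_grad:
            ctx.save_for_backward(*tensors_to_save)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[Optional[torch.Tensor], ...]]:
        # d(output)/d(input_) is the scalar factor; this op has no
        # parameters, so the parameter-gradient tuple is empty.
        return grad_output * self.scale, ()

Because ``op_forward_compute`` appears in the subclass ``__dict__``, ``BasicOperation.__init_subclass__`` sets ``_use_split_forward = True``, and the fuser's forward pass takes the split branch added in ``fuser.py``: it calls ``fuser_forward_compute`` first and then passes the returned ``tensors_to_save`` to ``fuser_forward_save_ctx``.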