Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ Operation fuser
:members: forward

.. autoapiclass:: transformer_engine.pytorch.ops.FusibleOperation
:members: fuser_forward, fuser_backward
:members: fuser_forward, fuser_forward_compute, fuser_forward_save_ctx, fuser_backward

.. autoapiclass:: transformer_engine.pytorch.ops.BasicOperation
:members: op_forward, op_backward
:members: op_forward, op_forward_compute, op_forward_save_ctx, op_backward

.. autoapiclass:: transformer_engine.pytorch.ops.FusedOperation
:members: fuser_forward, fuser_backward
:members: fuser_forward, fuser_forward_compute, fuser_forward_save_ctx, fuser_backward

.. autoapifunction:: transformer_engine.pytorch.ops.register_forward_fusion

Expand Down
46 changes: 32 additions & 14 deletions transformer_engine/pytorch/ops/basic/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from __future__ import annotations
import abc
from typing import Optional
from typing import Optional, Union

import torch

import transformer_engine_torch as tex
from ...quantized_tensor import QuantizedTensorStorage
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
Expand Down Expand Up @@ -80,13 +81,14 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:

"""

def op_forward(
def op_forward_compute(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
*,
requires_grad: bool,
prev_op_grad_output_quantizer: Optional[Quantizer] = None,
next_op_input_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

# Compute dtype
dtype: torch.dtype
Expand All @@ -109,15 +111,31 @@ def op_forward(
input_quantizer.set_usage(rowwise=True, columnwise=False)
x = input_quantizer(x)

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
if requires_grad:
return y, (x,)
return y, (None,)

return y
def op_forward_save_ctx(
    self,
    ctx: OperationContext,
    input_: torch.Tensor,
    tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
    *,
    requires_grad: bool,
    prev_op_grad_output_quantizer: Optional[Quantizer] = None,
    next_op_input_quantizer: Optional[Quantizer] = None,
) -> None:
    """Stash the forward tensors and metadata needed by the backward pass.

    No-op when gradients are not required.
    """
    if requires_grad:
        (saved_act_input,) = tensors_to_save
        if is_cpu_offload_enabled():
            mark_activation_offload(saved_act_input)
        ctx.save_for_backward(saved_act_input)
        # Backward computes in the autocast dtype when autocast is active,
        # otherwise in the forward input's dtype.
        ctx.dtype = (
            torch.get_autocast_dtype("cuda")
            if torch.is_autocast_enabled()
            else input_.dtype
        )
        ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

def op_backward(
self,
Expand Down
32 changes: 19 additions & 13 deletions transformer_engine/pytorch/ops/basic/add_extra_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,40 +42,46 @@ def __init__(self, *, in_place: bool = False):
super().__init__()
self._in_place = in_place

def op_forward(self, *args, **kwargs) -> None:
    """Raise: this op handles extra inputs/outputs, so it must be fused.

    Raises:
        RuntimeError: always; callers must go through `fuser_forward`.
    """
    raise RuntimeError(
        # Bug fix: the first fragment was missing the f-prefix, so the
        # class name placeholder was emitted literally.
        f"{self.__class__.__name__} operation has "
        f"{self.num_extra_inputs} extra tensor inputs "
        f"and {self.num_extra_outputs} extra tensor outputs. "
        "It overrides `fuser_forward` instead of `op_forward`."
    )

def op_backward(self, *args, **kwargs) -> None:
    """Raise: this op handles extra inputs/outputs, so it must be fused.

    Raises:
        RuntimeError: always; callers must go through `fuser_backward`.
    """
    raise RuntimeError(
        # Bug fix: the message contained a duplicated "operation has"
        # fragment (one of them a plain string with a literal "{self...}"
        # placeholder); keep the single f-string form.
        f"{self.__class__.__name__} operation has "
        f"{self.num_extra_inputs} extra tensor inputs "
        f"and {self.num_extra_outputs} extra tensor outputs. "
        "It overrides `fuser_backward` instead of `op_backward`."
    )

def fuser_forward(
def fuser_forward_compute(
    self,
    basic_op_ctxs: list[OperationContext],
    input_: torch.Tensor,
    *,
    requires_grad: list[bool],
    basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
    prev_op_grad_output_quantizer: Optional[Quantizer],
    next_op_input_quantizer: Optional[Quantizer],
    basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]], list[tuple[()]]]:
    """Add the first extra input to the main input.

    In in-place mode the sum is accumulated into (a detached view of)
    the extra input; otherwise a fresh tensor is produced. Nothing is
    saved for the backward pass.
    """
    side_input = basic_op_extra_inputs[0][0]
    if not self._in_place:
        result = side_input + input_
    else:
        # Detach so the in-place accumulation does not rewrite the
        # extra input's autograd history.
        result = side_input.detach()
        result += input_
    return result, [()], [()]

def fuser_forward_save_ctx(
    self,
    basic_op_ctxs: list[OperationContext],
    input_: torch.Tensor,
    tensors_to_save: list[tuple[()]],
    *,
    requires_grad: list[bool],
    basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
    prev_op_grad_output_quantizer: Optional[Quantizer],
    next_op_input_quantizer: Optional[Quantizer],
    basic_op_kwargs: list[dict[str, Any]],
) -> None:
    """No-op: the add operation has no state to stash for backward."""

def fuser_backward(
self,
Expand Down
25 changes: 19 additions & 6 deletions transformer_engine/pytorch/ops/basic/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,32 @@ def __init__(
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self.process_group_size: int = torch.distributed.get_world_size(process_group)

def op_forward(
def op_forward_compute(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
*,
requires_grad: bool,
prev_op_grad_output_quantizer: Optional[Quantizer] = None,
next_op_input_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, tuple[()]]:
out: torch.Tensor
if self.process_group_size == 1:
out = input_.detach()
else:
out, _ = gather_along_first_dim(input_, self.process_group)
return out
return out, ()

def op_forward_save_ctx(
    self,
    ctx: OperationContext,
    input_: torch.Tensor,
    tensors_to_save: tuple[()],
    *,
    requires_grad: bool,
    prev_op_grad_output_quantizer: Optional[Quantizer] = None,
    next_op_input_quantizer: Optional[Quantizer] = None,
) -> None:
    """No-op: all-gather saves nothing for the backward pass."""

def op_backward(
self,
Expand Down
27 changes: 20 additions & 7 deletions transformer_engine/pytorch/ops/basic/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,35 @@ def __init__(
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self._reduce_in_backward: bool = reduce_in_backward

def op_forward(
def op_forward_compute(
    self,
    ctx: OperationContext,
    input_: torch.Tensor,
    *,
    requires_grad: bool,
    prev_op_grad_output_quantizer: Optional[Quantizer] = None,
    next_op_input_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, tuple[()]]:
    """All-reduce the input across the process group.

    Returns the reduced tensor and an empty tuple of tensors to save
    (backward needs no forward state).
    """
    world_size = torch.distributed.get_world_size(self.process_group)

    # Single-rank group: reduction is the identity.
    if world_size == 1:
        return input_, ()

    # Reduce a contiguous copy; maybe_dequantize presumably yields a
    # plain tensor, since collectives do not understand quantized types.
    reduced = maybe_dequantize(input_.contiguous())
    torch.distributed.all_reduce(reduced, group=self.process_group)
    return reduced, ()

def op_forward_save_ctx(
    self,
    ctx: OperationContext,
    input_: torch.Tensor,
    tensors_to_save: tuple[()],
    *,
    requires_grad: bool,
    prev_op_grad_output_quantizer: Optional[Quantizer] = None,
    next_op_input_quantizer: Optional[Quantizer] = None,
) -> None:
    """No-op: all-reduce saves nothing for the backward pass."""

def op_backward(
self,
Expand Down
80 changes: 54 additions & 26 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from collections.abc import Callable, Iterable
import contextlib
import math
from typing import Any, Optional
from typing import Any, Optional, Union

import torch

from ...quantized_tensor import QuantizedTensorStorage
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
Expand Down Expand Up @@ -979,24 +980,23 @@ def _functional_backward(
_wait_async(dx_async)
return dx, dw

def op_forward(
def op_forward_compute(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
*,
requires_grad: bool,
prev_op_grad_output_quantizer: Optional[Quantizer] = None,
next_op_input_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...]]:

# Check which grads are required
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
input_requires_grad = requires_grad
weight_requires_grad = requires_grad and self.weight.requires_grad

# Quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
Expand Down Expand Up @@ -1026,28 +1026,56 @@ def op_forward(
weight_requires_grad=weight_requires_grad,
)

# Save state for backward pass
if ctx.requires_grad:
# Determine tensors to save for backward pass
if requires_grad:
if backward_override == "high_precision":
saved_input = input_ if weight_requires_grad else None
saved_weight = self.weight if input_requires_grad else None
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
ctx.save_for_backward(saved_input, saved_weight)
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
ctx.backward_override = backward_override
ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad

return output
else:
saved_input = None
saved_weight = None

return output, (saved_input, saved_weight)

def op_forward_save_ctx(
    self,
    ctx: OperationContext,
    input_: torch.Tensor,
    tensors_to_save: tuple[Optional[Union[torch.Tensor, QuantizedTensorStorage]], ...],
    *,
    requires_grad: bool,
    prev_op_grad_output_quantizer: Optional[Quantizer] = None,
    next_op_input_quantizer: Optional[Quantizer] = None,
) -> None:
    """Save linear-op tensors and metadata needed by the backward pass.

    Stashes the (possibly quantized, possibly None) input/weight pair
    produced by the forward compute step, along with the quantizers,
    dtype, and grad-requirement flags that the backward step reads.
    No-op when gradients are not required.

    Args:
        ctx: Context object that carries state to the backward pass.
        input_: Forward input tensor (unused except via `tensors_to_save`).
        tensors_to_save: Two-element tuple `(saved_input, saved_weight)`
            as returned by the forward compute step; elements may be None.
        requires_grad: Whether any gradient is required for this op.
        prev_op_grad_output_quantizer: Quantizer used for the grad input
            passed back to the previous op, if any.
        next_op_input_quantizer: Unused here.
    """
    if not requires_grad:
        return

    saved_input, saved_weight = tensors_to_save
    if is_cpu_offload_enabled():
        # Only the activation (input) is offloaded to CPU, not the weight.
        mark_activation_offload(saved_input)
    ctx.save_for_backward(saved_input, saved_weight)

    # NOTE(review): the FP8 state is re-queried here rather than passed
    # from the compute step — assumes it has not changed in between;
    # confirm this holds when compute and save_ctx run at different times.
    with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
    if with_quantized_compute:
        backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
    else:
        backward_override = None

    # A backward override disables quantized compute in the backward pass.
    ctx.with_quantized_compute = with_quantized_compute and backward_override is None
    ctx.backward_override = backward_override
    ctx.input_quantizer = self.get_quantizer("forward", 0)
    ctx.weight_quantizer = self.get_quantizer("forward", 1)
    ctx.grad_output_quantizer = self.get_quantizer("backward", 0)
    ctx.grad_input_quantizer = prev_op_grad_output_quantizer
    if torch.is_autocast_enabled():
        # Backward computes in the autocast dtype when autocast is active.
        ctx.dtype = torch.get_autocast_dtype("cuda")
    else:
        # Falls back to the weight dtype — presumably matching the
        # compute dtype chosen in the forward step; TODO confirm.
        ctx.dtype = self.weight.dtype
    ctx.input_requires_grad = requires_grad
    ctx.weight_requires_grad = requires_grad and self.weight.requires_grad

def op_backward(
self,
Expand Down
Loading
Loading