Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
Expand Down
41 changes: 31 additions & 10 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
Expand Down Expand Up @@ -123,21 +124,25 @@ def __fx_repr__(self):
_Q = get_opaque_type_name(ToyQuantizer)

def _make_qfactory(tag: str):
    """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.

    The factory dispatches on ``QuantizerRole.tensor_type``; the roles are
    supplied by :meth:`ToyLinear.get_quantizer_roles`.

    One ``ToyQuantizer`` is built up front for each supported tensor type, so
    repeated calls to the returned factory with the same tensor type yield the
    same quantizer instance. A role whose ``tensor_type`` is not one of the
    five names below raises ``KeyError``.
    """

    # Keys must stay in sync with the ``tensor_type`` values of the roles the
    # module hands to the factory.
    quantizers = {
        tensor_type: ToyQuantizer(tag=f"{tag}:{tensor_type}")
        for tensor_type in (
            "input",
            "weight",
            "output",
            "grad_output",
            "grad_input",
        )
    }

    def qfactory(role: QuantizerRole):
        # Look up the pre-built quantizer for this role's tensor slot.
        return quantizers[role.tensor_type]

    return qfactory

Expand All @@ -163,6 +168,22 @@ def __init__(
)
torch.nn.init.normal_(self.weight)

def get_quantizer_roles(self, *, fwd: bool, num_quantizers: int):
    """Return the explicit ``QuantizerRole`` list for this toy linear module.

    With ``fwd`` true the roles are input, weight and output; otherwise they
    are grad_output and grad_input. ``num_quantizers`` is accepted but not
    read here.
    """
    # Explicit roles keep CustomRecipeState from emitting a warning (which
    # would graph-break under fullgraph=True) and let the qfactory dispatch
    # per tensor slot.
    # The order of each tuple must match the module's quantizer array
    # (FP8FwdTensorIdx / FP8BwdTensorIdx).
    if fwd:
        slot_names = ("input", "weight", "output")
    else:
        slot_names = ("grad_output", "grad_input")
    return [
        QuantizerRole(module_type="linear", tensor_type=slot_name)
        for slot_name in slot_names
    ]

def _get_weight_tensors(self):
return [self.weight]

Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from ..utils import get_default_init_method


def set_quantizer_amax_reduction_group(quantizer, amax_reduction_group) -> None:
    """Point *quantizer* at *amax_reduction_group*, if it supports amax reduction.

    Does nothing when *quantizer* is ``None`` or has no ``with_amax_reduction``
    attribute. Otherwise ``with_amax_reduction`` becomes ``True`` exactly when a
    group is given, and ``amax_reduction_group`` is set to the group (or
    ``None``, which clears it).
    """
    supports_reduction = quantizer is not None and hasattr(
        quantizer, "with_amax_reduction"
    )
    if not supports_reduction:
        return

    reduction_enabled = amax_reduction_group is not None
    quantizer.with_amax_reduction = reduction_enabled
    quantizer.amax_reduction_group = amax_reduction_group


def _get_normalization_func(normalization: str, forward: bool):
fwd_normalization_funcs = {
"LayerNorm": tex.layernorm_fwd,
Expand Down
12 changes: 11 additions & 1 deletion transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool):
return _ub_communicators[key]


@torch.compiler.assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
    """Report whether the named UB communicator uses an FP8 buffer.

    Looks the communicator up with ``get_ub(name, use_fp8)`` and returns its
    ``is_fp8_ubuf()`` result. The decorator marks the result as a
    compile-time constant for ``torch.compile``.
    """
    communicator = get_ub(name, use_fp8)
    return communicator.is_fp8_ubuf()


def destroy_ub():
"""Destroy all allocated userbuffer communicators."""
global _ub_communicators, _ub_with_cublasmp, _ub_initialized
Expand All @@ -562,6 +568,9 @@ def destroy_ub():
_ub_initialized = False
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
# Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result;
# reset so re-init with different settings doesn't read stale constants.
torch.compiler.reset()


def fill_userbuffers_buffer_for_all_gather(
Expand Down Expand Up @@ -1049,7 +1058,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
return
if recipe.custom() and isinstance(recipe_state, CustomRecipeState):
return
if recipe_state.recipe is recipe:
return

# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
Expand Down
78 changes: 27 additions & 51 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -58,7 +59,12 @@
from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ._common import (
apply_normalization,
noop_cat,
set_quantizer_amax_reduction_group,
WeightGradStore,
)
from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Expand Down Expand Up @@ -215,6 +221,11 @@ def forward(
if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Amax reduction group for the input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
input_quantizer,
tp_group if (sequence_parallel and parallel_mode == "column") else None,
)

# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
Expand Down Expand Up @@ -690,6 +701,15 @@ def backward(
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Amax reduction group for grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
quantizer,
(
ctx.tp_group
if (ctx.sequence_parallel and ctx.parallel_mode == "row")
else None
),
)

# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
Expand Down Expand Up @@ -1051,8 +1071,10 @@ def wgrad_gemm(
if ctx.ln_out_needs_gather:
# Gathered input is internal
clear_tensor_data(ln_out_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
if ctx.sequence_parallel and (
ctx.parallel_mode == "row" or (ctx.parallel_mode == "column" and ctx.fp8)
):
# Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal
clear_tensor_data(grad_output)

# Update grad input if overlapping reduce-scatter with wgrad GEMM
Expand Down Expand Up @@ -1552,8 +1574,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
Expand Down Expand Up @@ -1668,14 +1688,10 @@ def forward(
is_first_microbatch = False

if self.ub_overlap_rs_fprop:
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()):
fp8_grad = True

inp = self.prepare_forward(
Expand Down Expand Up @@ -1919,15 +1935,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
self.quantizers["scaling_bwd"][
Expand All @@ -1936,37 +1943,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
Expand Down
56 changes: 13 additions & 43 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
fill_userbuffers_buffer_for_all_gather,
_ub_communicators,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -73,7 +74,7 @@
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ._common import apply_normalization, set_quantizer_amax_reduction_group, WeightGradStore
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
Expand Down Expand Up @@ -399,6 +400,11 @@ def _forward(
if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Amax reduction group for the FC1 input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
fc1_input_quantizer,
tp_group if (sequence_parallel and set_parallel_mode) else None,
)

# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
Expand Down Expand Up @@ -1138,6 +1144,11 @@ def backward(
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Amax reduction group for FC2 grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
quantizer,
ctx.tp_group if (ctx.sequence_parallel and ctx.set_parallel_mode) else None,
)

# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
Expand Down Expand Up @@ -2165,8 +2176,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
Expand Down Expand Up @@ -2292,7 +2301,7 @@ def forward(

fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
if get_ub_is_fp8("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True

inp = self.prepare_forward(inp, num_gemms=2)
Expand Down Expand Up @@ -2676,15 +2685,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM2_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][
Expand All @@ -2700,36 +2700,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
Expand Down
Loading
Loading