diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 92f73d5885..48b713ad66 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 309a5d124e..1286492a6e 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch import ( is_fp8_available, is_mxfp8_available, @@ -123,21 +124,25 @@ def __fx_repr__(self): _Q = get_opaque_type_name(ToyQuantizer) def _make_qfactory(tag: str): - """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" + """Return a qfactory that produces ToyQuantizer instances tagged with *tag*. + + The factory dispatches on ``QuantizerRole.tensor_type``; the roles are + supplied by :meth:`ToyLinear.get_quantizer_roles`. + """ quantizers = { - role: ToyQuantizer(tag=f"{tag}:{role}") - for role in ( - "linear_input", - "linear_weight", - "linear_output", - "linear_grad_output", - "linear_grad_input", + tensor_type: ToyQuantizer(tag=f"{tag}:{tensor_type}") + for tensor_type in ( + "input", + "weight", + "output", + "grad_output", + "grad_input", ) } - def qfactory(role: str): - return quantizers[role] + def qfactory(role: QuantizerRole): + return quantizers[role.tensor_type] return qfactory @@ -163,6 +168,22 @@ def __init__( ) torch.nn.init.normal_(self.weight) + def get_quantizer_roles(self, *, fwd: bool, num_quantizers: int): + # Supplying explicit roles keeps CustomRecipeState from emitting a + # warning (which would graph-break under fullgraph=True) and lets the + # qfactory dispatch per tensor slot. Order must match the module's + # quantizer array (FP8FwdTensorIdx / FP8BwdTensorIdx). + if fwd: + return [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + return [ + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ] + def _get_weight_tensors(self): return [self.weight] diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index bf5a230e84..b58b69acb3 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -16,6 +16,13 @@ from ..utils import get_default_init_method +def set_quantizer_amax_reduction_group(quantizer, amax_reduction_group) -> None: + """Set the amax reduction group on a quantizer; no-op if it doesn't support it.""" + if quantizer is not None and hasattr(quantizer, "with_amax_reduction"): + quantizer.with_amax_reduction = amax_reduction_group is not None + quantizer.amax_reduction_group = amax_reduction_group + + def _get_normalization_func(normalization: str, forward: bool): fwd_normalization_funcs = { "LayerNorm": tex.layernorm_fwd, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..3ca3813330 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool): return _ub_communicators[key] +@torch.compiler.assume_constant_result +def get_ub_is_fp8(name: str, use_fp8: bool) -> bool: + """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant.""" + return get_ub(name, use_fp8).is_fp8_ubuf() + + def destroy_ub(): """Destroy all allocated userbuffer communicators.""" global _ub_communicators, _ub_with_cublasmp, _ub_initialized @@ -562,6 +568,9 @@ def destroy_ub(): _ub_initialized = False global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] + # Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result; + # reset so re-init with different settings doesn't read stale constants. + torch.compiler.reset() def fill_userbuffers_buffer_for_all_gather( @@ -1049,7 +1058,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return if recipe.custom() and isinstance(recipe_state, CustomRecipeState): - return + if recipe_state.recipe is recipe: + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7fc96d4779..f187569d49 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -21,6 +21,7 @@ from .base import ( fill_userbuffers_buffer_for_all_gather, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -58,7 +59,12 @@ from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, WeightGradStore +from ._common import ( + apply_normalization, + noop_cat, + set_quantizer_amax_reduction_group, + WeightGradStore, +) from ..quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, @@ -215,6 +221,11 @@ def forward( if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) + # Amax reduction group for the input quantizer (column-parallel sequence parallel) + set_quantizer_amax_reduction_group( + input_quantizer, + tp_group if (sequence_parallel and parallel_mode == "column") else None, + ) # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. @@ -690,6 +701,15 @@ def backward( # tensor usage at a time. Configure quantizer with # usage for only dgrad GEMM. quantizer.set_usage(columnwise=False) + # Amax reduction group for grad output (row-parallel sequence parallel) + set_quantizer_amax_reduction_group( + quantizer, + ( + ctx.tp_group + if (ctx.sequence_parallel and ctx.parallel_mode == "row") + else None + ), + ) # Prepare grad output tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -1051,8 +1071,10 @@ def wgrad_gemm( if ctx.ln_out_needs_gather: # Gathered input is internal clear_tensor_data(ln_out_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: - # Gathered grad output tensor is internal + if ctx.sequence_parallel and ( + ctx.parallel_mode == "row" or (ctx.parallel_mode == "column" and ctx.fp8) + ): + # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM @@ -1552,8 +1574,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) def get_quantizer_roles( self, @@ -1668,14 +1688,10 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()): fp8_grad = True inp = self.prepare_forward( @@ -1919,15 +1935,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_fwd"][ FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "column": - # set input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) self.quantizers["scaling_bwd"][ @@ -1936,37 +1943,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + layernorm_linear.""" - assert recipe.nvfp4(), "Incorrect recipe." - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - # set input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c6cca74ef..695dd98916 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -23,6 +23,7 @@ fill_userbuffers_buffer_for_all_gather, _ub_communicators, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -73,7 +74,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ._common import apply_normalization, WeightGradStore +from ._common import apply_normalization, set_quantizer_amax_reduction_group, WeightGradStore from ..cpu_offload import ( is_cpu_offload_enabled, start_offload, @@ -399,6 +400,11 @@ def _forward( if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data fc1_input_quantizer.set_usage(columnwise=False) + # Amax reduction group for the FC1 input quantizer (column-parallel sequence parallel) + set_quantizer_amax_reduction_group( + fc1_input_quantizer, + tp_group if (sequence_parallel and set_parallel_mode) else None, + ) # for fp8 DelayedScaling: layernorm output = FP8 # only output of the linear is returned @@ -1138,6 +1144,11 @@ def backward( # tensor usage at a time. Configure quantizer with # usage for only dgrad GEMM. quantizer.set_usage(columnwise=False) + # Amax reduction group for FC2 grad output (row-parallel sequence parallel) + set_quantizer_amax_reduction_group( + quantizer, + ctx.tp_group if (ctx.sequence_parallel and ctx.set_parallel_mode) else None, + ) # Prepare FC2 grad output tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -2165,8 +2176,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) def get_quantizer_roles( self, @@ -2292,7 +2301,7 @@ def forward( fp8_output = False if self.ub_overlap_rs: - if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): + if get_ub_is_fp8("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True inp = self.prepare_forward(inp, num_gemms=2) @@ -2676,15 +2685,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_fwd"][ FP8FwdTensorIdx.GEMM2_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # parallel related - if self.sequence_parallel and self.set_parallel_mode: - # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group else: # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ @@ -2700,36 +2700,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - if self.sequence_parallel and self.set_parallel_mode: - # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + layernorm_mlp.""" - assert recipe.nvfp4(), "Incorrect recipe." - if fwd: - if self.sequence_parallel and self.set_parallel_mode: - # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.set_parallel_mode: - # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..a2afad271a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -21,6 +21,7 @@ fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -29,7 +30,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, WeightGradStore +from ._common import noop_cat, set_quantizer_amax_reduction_group, WeightGradStore from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, @@ -117,6 +118,8 @@ class LinearFwdArgs: fp8_output: bool save_original_input: bool backward_override: Optional[str] + dgrad_use_split_accumulator: bool + wgrad_use_split_accumulator: bool custom: bool debug: bool @@ -183,7 +186,8 @@ class LinearBwdArgs: # --- Numerical / dtype config --- activation_dtype: Optional[torch.dtype] = None fp8: bool = False - fp8_recipe: Optional[Recipe] = None + dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD + wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD backward_override: Optional[str] = None is_weight_param_quantized: bool = False custom: bool = False @@ -302,6 +306,12 @@ def _linear_forward_impl( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) + # Amax reduction group for the input quantizer (column-parallel sequence parallel) + set_quantizer_amax_reduction_group( + input_quantizer, + tp_group if (sequence_parallel and parallel_mode == "column") else None, + ) + # Configure Userbuffers communication (comm+GEMM overlap) ub_obj = None ub_type = None @@ -656,7 +666,8 @@ def _linear_setup_ctx( # Numerical / dtype config bwd_args.activation_dtype = fwd_args.activation_dtype bwd_args.fp8 = fp8 - bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + bwd_args.dgrad_use_split_accumulator = fwd_args.dgrad_use_split_accumulator + bwd_args.wgrad_use_split_accumulator = fwd_args.wgrad_use_split_accumulator bwd_args.backward_override = backward_override bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) bwd_args.custom = fwd_args.custom @@ -743,6 +754,24 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_weight_quantizer = args.grad_weight_quantizer grad_output_quantizer = args.grad_output_quantizer + # Amax reduction groups (sequence parallel): input for column-parallel, grad output for row-parallel + set_quantizer_amax_reduction_group( + input_quantizer, + ( + bwd_args.tp_group + if (bwd_args.sequence_parallel and bwd_args.parallel_mode == "column") + else None + ), + ) + set_quantizer_amax_reduction_group( + grad_output_quantizer, + ( + bwd_args.tp_group + if (bwd_args.sequence_parallel and bwd_args.parallel_mode == "row") + else None + ), + ) + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.backward" if bwd_args.ub_name is not None: @@ -958,11 +987,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + use_split_accumulator = bwd_args.dgrad_use_split_accumulator # Update grad input quantizer if grad_input_quantizer is not None: @@ -1133,11 +1158,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_output = grad_output_quantizer(grad_output) # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + use_split_accumulator = bwd_args.wgrad_use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad if bwd_args.is_first_microbatch is not None: @@ -1228,8 +1249,11 @@ def wgrad_gemm( elif bwd_args.backward_input_needs_gather: # Gathered input tensor is internal clear_tensor_data(inputmat_total) - if bwd_args.parallel_mode == "row" and bwd_args.sequence_parallel: - # Gathered grad output tensor is internal + if bwd_args.sequence_parallel and ( + bwd_args.parallel_mode == "row" + or (bwd_args.parallel_mode == "column" and bwd_args.fp8) + ): + # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM @@ -1746,8 +1770,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -1816,14 +1838,10 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()): fp8_grad = True inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) @@ -1861,8 +1879,15 @@ def forward( self._fp8_workspaces.get(cache_name) if cache_name is not None else None ) + dgrad_use_split_accumulator = _2X_ACC_DGRAD + wgrad_use_split_accumulator = _2X_ACC_WGRAD if self.fp8: - backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override + _recipe = FP8GlobalStateManager.get_fp8_recipe() + backward_override = _recipe.backward_override + if hasattr(_recipe, "fp8_gemm_dgrad"): + dgrad_use_split_accumulator = _recipe.fp8_gemm_dgrad.use_split_accumulator + if hasattr(_recipe, "fp8_gemm_wgrad"): + wgrad_use_split_accumulator = _recipe.fp8_gemm_wgrad.use_split_accumulator else: backward_override = None custom = is_custom(input_quantizer) or is_custom(weight_quantizer) @@ -1917,6 +1942,8 @@ def forward( fp8_output=fp8_output, save_original_input=self.save_original_input, backward_override=backward_override, + dgrad_use_split_accumulator=dgrad_use_split_accumulator, + wgrad_use_split_accumulator=wgrad_use_split_accumulator, custom=custom, debug=debug, # weight-workspace caching @@ -2111,15 +2138,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_fwd"][ FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # paralle related - if self.sequence_parallel and self.parallel_mode == "column": - # customize input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale self.quantizers["scaling_bwd"][ @@ -2128,37 +2146,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + linear.""" - assert recipe.nvfp4(), "Incorrect recipe." - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - # customize input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 6b17d66fcd..cb429055a4 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -25,6 +25,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ...module._common import set_quantizer_amax_reduction_group from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer from ...tensor.storage.float8_tensor_storage import Float8TensorStorage @@ -401,23 +402,6 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon - if getattr(self, "sequence_parallel", False): - tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) - if tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - elif tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group - if recipe.nvfp4(): - if getattr(self, "sequence_parallel", False): - tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) - if tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - elif tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Update quantizer in quantized weight tensor if weight_quantizer is not None and is_quantized_tensor(weight): @@ -544,6 +528,10 @@ def _functional_forward( rowwise=True, columnwise=weight_requires_grad and backward_override is None, ) + # Amax reduction group for the input quantizer (column-parallel sequence parallel) + set_quantizer_amax_reduction_group( + input_quantizer, tensor_parallel_group if with_x_all_gather else None + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -788,6 +776,10 @@ def _functional_backward( rowwise=input_requires_grad, columnwise=weight_requires_grad, ) + # Amax reduction group for grad output (row-parallel sequence parallel) + set_quantizer_amax_reduction_group( + grad_output_quantizer, tensor_parallel_group if with_dy_all_gather else None + ) if with_dy_all_gather: dy, dy_async = gather_along_first_dim( dy_local, @@ -828,6 +820,10 @@ def _functional_backward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage(rowwise=False, columnwise=True) + # Amax reduction group for the input quantizer (column-parallel sequence parallel) + set_quantizer_amax_reduction_group( + input_quantizer, tensor_parallel_group if with_x_all_gather else None + ) if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 404796fd63..cfe488aae5 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -288,8 +288,15 @@ def quantize( if out is not None: return self.update_quantized(tensor, out) if (not self.internal) and torch.is_grad_enabled(): - return _QuantizeFunc.apply(tensor, self.quantize_impl) - return _QuantizeFunc.forward(None, tensor, self.quantize_impl) + result = _QuantizeFunc.apply(tensor, self.quantize_impl) + else: + result = _QuantizeFunc.forward(None, tensor, self.quantize_impl) + # The amax reduction group must never persist on a tensor's quantizer + result_quantizer = getattr(result, "_quantizer", None) + if getattr(result_quantizer, "with_amax_reduction", False): + result_quantizer.with_amax_reduction = False + result_quantizer.amax_reduction_group = None + return result def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 4de8d82217..e26abf7df0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -282,8 +282,16 @@ def update_quantized( if not src.is_contiguous(): src = src.contiguous() + # Apply the destination tensor's amax reduction group on a throwaway copy + quantizer = self + group = getattr(dst, "amax_reduction_group", None) + if group is not None: + quantizer = self.copy() + quantizer.with_amax_reduction = True + quantizer.amax_reduction_group = group + # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + tex.quantize(src, quantizer, dst, noop_flag) # Update FP8 dtype dst._fp8_dtype = self.dtype @@ -411,6 +419,9 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ + # Optional amax all-reduce group, set by FSDP2 in ``fsdp_pre_all_gather`` + amax_reduction_group: Optional[dist_group_type] = None + def __repr__(self, *, tensor_contents=None): return ( "Float8Tensor(" @@ -785,12 +796,8 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m from transformer_engine.pytorch.distributed import _get_module_fsdp_state if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None: - # When sharded weight is updated after reduce scattering the gradients in FSDP2, - # we need to do amax reduction across the mesh to make sure all weight shards are - # updated with same scale inverse. Setting the state below in the quantizer will make - # sure that updated Quantized weight tensor have same scale inverse across all shards. - self._quantizer.amax_reduction_group = mesh.get_group() - self._quantizer.with_amax_reduction = True + # Reduce amax across the mesh so all weight shards get the same scale inverse + self.amax_reduction_group = mesh.get_group() fsdp_state = _get_module_fsdp_state(module) param_group = fsdp_state._fsdp_param_group @@ -995,7 +1002,14 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" self._quantizer.internal = False - self.data = self._quantizer.quantize(tensor) + # Apply this tensor's amax reduction group (set by FSDP2) on a throwaway copy + quantizer = self._quantizer + group = getattr(self, "amax_reduction_group", None) + if group is not None and isinstance(quantizer, Float8CurrentScalingQuantizer): + quantizer = quantizer.copy() + quantizer.with_amax_reduction = True + quantizer.amax_reduction_group = group + self.data = quantizer.quantize(tensor) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5a2765b9f5..aa92be004f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -200,8 +200,16 @@ def update_quantized( if not src.is_contiguous(): src = src.contiguous() + # Apply the destination tensor's amax reduction group on a throwaway copy + quantizer = self + group = getattr(dst, "amax_reduction_group", None) + if group is not None: + quantizer = self.copy() + quantizer.with_amax_reduction = True + quantizer.amax_reduction_group = group + # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + tex.quantize(src, quantizer, dst, noop_flag) return dst @@ -359,6 +367,9 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): Nominal tensor datatype, used in dequantize. """ + # Optional amax all-reduce group, set by FSDP2 in ``fsdp_pre_all_gather`` + amax_reduction_group: Optional[dist_group_type] = None + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( @@ -513,6 +524,10 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m "FSDP2 is not supported for NVFP4Tensors with GEMM-swizzled scales." ) + if mesh is not None: + # Reduce amax across the mesh so all weight shards get the same scale + self.amax_reduction_group = mesh.get_group() + shard_M = math.prod(self.shape[:-1]) assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, (