diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 92f73d5885..48b713ad66 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 309a5d124e..1286492a6e 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch import ( is_fp8_available, is_mxfp8_available, @@ -123,21 +124,25 @@ def __fx_repr__(self): _Q = get_opaque_type_name(ToyQuantizer) def _make_qfactory(tag: str): - """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" + """Return a qfactory that produces ToyQuantizer instances tagged with *tag*. + + The factory dispatches on ``QuantizerRole.tensor_type``; the roles are + supplied by :meth:`ToyLinear.get_quantizer_roles`. + """ quantizers = { - role: ToyQuantizer(tag=f"{tag}:{role}") - for role in ( - "linear_input", - "linear_weight", - "linear_output", - "linear_grad_output", - "linear_grad_input", + tensor_type: ToyQuantizer(tag=f"{tag}:{tensor_type}") + for tensor_type in ( + "input", + "weight", + "output", + "grad_output", + "grad_input", ) } - def qfactory(role: str): - return quantizers[role] + def qfactory(role: QuantizerRole): + return quantizers[role.tensor_type] return qfactory @@ -163,6 +168,22 @@ def __init__( ) torch.nn.init.normal_(self.weight) + def get_quantizer_roles(self, *, fwd: bool, num_quantizers: int): + # Supplying explicit roles keeps CustomRecipeState from emitting a + # warning (which would graph-break under fullgraph=True) and lets the + # qfactory dispatch per tensor slot. Order must match the module's + # quantizer array (FP8FwdTensorIdx / FP8BwdTensorIdx). + if fwd: + return [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + return [ + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ] + def _get_weight_tensors(self): return [self.weight] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..3ca3813330 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool): return _ub_communicators[key] +@torch.compiler.assume_constant_result +def get_ub_is_fp8(name: str, use_fp8: bool) -> bool: + """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant.""" + return get_ub(name, use_fp8).is_fp8_ubuf() + + def destroy_ub(): """Destroy all allocated userbuffer communicators.""" global _ub_communicators, _ub_with_cublasmp, _ub_initialized @@ -562,6 +568,9 @@ def destroy_ub(): _ub_initialized = False global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] + # Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result; + # reset so re-init with different settings doesn't read stale constants. + torch.compiler.reset() def fill_userbuffers_buffer_for_all_gather( @@ -1049,7 +1058,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return if recipe.custom() and isinstance(recipe_state, CustomRecipeState): - return + if recipe_state.recipe is recipe: + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7fc96d4779..886a3d39a1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -21,6 +21,7 @@ from .base import ( fill_userbuffers_buffer_for_all_gather, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -1051,8 +1052,10 @@ def wgrad_gemm( if ctx.ln_out_needs_gather: # Gathered input is internal clear_tensor_data(ln_out_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: - # Gathered grad output tensor is internal + if ctx.sequence_parallel and ( + ctx.parallel_mode == "row" or (ctx.parallel_mode == "column" and ctx.fp8) + ): + # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM @@ -1668,14 +1671,10 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()): fp8_grad = True inp = self.prepare_forward( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c6cca74ef..aa8fa67f87 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -23,6 +23,7 @@ fill_userbuffers_buffer_for_all_gather, _ub_communicators, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -2292,7 +2293,7 @@ def forward( fp8_output = False if self.ub_overlap_rs: - if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): + if get_ub_is_fp8("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True inp = self.prepare_forward(inp, num_gemms=2) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..0593226965 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -21,6 +21,7 @@ fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -117,6 +118,8 @@ class LinearFwdArgs: fp8_output: bool save_original_input: bool backward_override: Optional[str] + dgrad_use_split_accumulator: bool + wgrad_use_split_accumulator: bool custom: bool debug: bool @@ -183,7 +186,8 @@ class LinearBwdArgs: # --- Numerical / dtype config --- activation_dtype: Optional[torch.dtype] = None fp8: bool = False - fp8_recipe: Optional[Recipe] = None + dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD + wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD backward_override: Optional[str] = None is_weight_param_quantized: bool = False custom: bool = False @@ -656,7 +660,8 @@ def _linear_setup_ctx( # Numerical / dtype config bwd_args.activation_dtype = fwd_args.activation_dtype bwd_args.fp8 = fp8 - bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + bwd_args.dgrad_use_split_accumulator = fwd_args.dgrad_use_split_accumulator + bwd_args.wgrad_use_split_accumulator = fwd_args.wgrad_use_split_accumulator bwd_args.backward_override = backward_override bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) bwd_args.custom = fwd_args.custom @@ -958,11 +963,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + use_split_accumulator = bwd_args.dgrad_use_split_accumulator # Update grad input quantizer if grad_input_quantizer is not None: @@ -1133,11 +1134,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_output = grad_output_quantizer(grad_output) # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + use_split_accumulator = bwd_args.wgrad_use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad if bwd_args.is_first_microbatch is not None: @@ -1228,8 +1225,11 @@ def wgrad_gemm( elif bwd_args.backward_input_needs_gather: # Gathered input tensor is internal clear_tensor_data(inputmat_total) - if bwd_args.parallel_mode == "row" and bwd_args.sequence_parallel: - # Gathered grad output tensor is internal + if bwd_args.sequence_parallel and ( + bwd_args.parallel_mode == "row" + or (bwd_args.parallel_mode == "column" and bwd_args.fp8) + ): + # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM @@ -1816,14 +1816,10 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()): fp8_grad = True inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) @@ -1861,8 +1857,15 @@ def forward( self._fp8_workspaces.get(cache_name) if cache_name is not None else None ) + dgrad_use_split_accumulator = _2X_ACC_DGRAD + wgrad_use_split_accumulator = _2X_ACC_WGRAD if self.fp8: - backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override + _recipe = FP8GlobalStateManager.get_fp8_recipe() + backward_override = _recipe.backward_override + if hasattr(_recipe, "fp8_gemm_dgrad"): + dgrad_use_split_accumulator = _recipe.fp8_gemm_dgrad.use_split_accumulator + if hasattr(_recipe, "fp8_gemm_wgrad"): + wgrad_use_split_accumulator = _recipe.fp8_gemm_wgrad.use_split_accumulator else: backward_override = None custom = is_custom(input_quantizer) or is_custom(weight_quantizer) @@ -1917,6 +1920,8 @@ def forward( fp8_output=fp8_output, save_original_input=self.save_original_input, backward_override=backward_override, + dgrad_use_split_accumulator=dgrad_use_split_accumulator, + wgrad_use_split_accumulator=wgrad_use_split_accumulator, custom=custom, debug=debug, # weight-workspace caching