From bfce3a7d02a0c87e0c3472bd40f2fadf68a0e6a4 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 15 Jun 2026 13:28:53 +0200 Subject: [PATCH 1/5] [PyTorch] torch.compile: wrap pybind11 UB methods as compile-time constants; fix SP memory leak; test suite hook-up MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks: - `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py - `with_cublasmp()` → `ub_is_cublasmp()` in base.py All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated. Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes. Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour. Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- qa/L0_pytorch_unittest/test.sh | 1 + transformer_engine/pytorch/module/base.py | 9 +++++++- .../pytorch/module/layernorm_linear.py | 22 +++++++++++-------- .../pytorch/module/layernorm_mlp.py | 3 ++- transformer_engine/pytorch/module/linear.py | 16 +++++++------- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 92f73d5885..48b713ad66 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..962f168831 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool): return _ub_communicators[key] +@torch.compiler.assume_constant_result +def get_ub_is_fp8(name: str, use_fp8: bool) -> bool: + """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant.""" + return get_ub(name, use_fp8).is_fp8_ubuf() + + def destroy_ub(): """Destroy all allocated userbuffer communicators.""" global _ub_communicators, _ub_with_cublasmp, _ub_initialized @@ -1049,7 +1055,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return if recipe.custom() and isinstance(recipe_state, CustomRecipeState): - return + if recipe_state.recipe is recipe: + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7fc96d4779..6d3b31e818 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -21,6 +21,7 @@ from .base import ( fill_userbuffers_buffer_for_all_gather, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -413,7 +414,11 @@ def forward( if ub_overlap_rs_fprop: # cuBLASMp writes the reduce-scattered output directly into the # GEMM output tensor; Userbuffers writes it into the extra-output buffer. - out = gemm_out if ub_obj is not None and ub_obj.with_cublasmp() else reduce_scatter_out + out = ( + gemm_out + if ub_obj is not None and ub_obj.with_cublasmp() + else reduce_scatter_out + ) elif parallel_mode == "row" and tp_size > 1: nvtx_range_push(f"{nvtx_label}.row_parallel_comm") out = gemm_out @@ -1051,8 +1056,11 @@ def wgrad_gemm( if ctx.ln_out_needs_gather: # Gathered input is internal clear_tensor_data(ln_out_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: - # Gathered grad output tensor is internal + if ctx.sequence_parallel and ( + ctx.parallel_mode == "row" + or (ctx.parallel_mode == "column" and ctx.fp8) + ): + # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM @@ -1668,14 +1676,10 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()): fp8_grad = True inp = self.prepare_forward( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c6cca74ef..aa8fa67f87 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -23,6 +23,7 @@ fill_userbuffers_buffer_for_all_gather, _ub_communicators, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -2292,7 +2293,7 @@ def forward( fp8_output = False if self.ub_overlap_rs: - if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): + if get_ub_is_fp8("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True inp = self.prepare_forward(inp, num_gemms=2) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..d8d0d001c5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -21,6 +21,7 @@ fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, + get_ub_is_fp8, is_ub_initialized, using_cublasmp_backend, quantize_weight, @@ -1228,8 +1229,11 @@ def wgrad_gemm( elif bwd_args.backward_input_needs_gather: # Gathered input tensor is internal clear_tensor_data(inputmat_total) - if bwd_args.parallel_mode == "row" and bwd_args.sequence_parallel: - # Gathered grad output tensor is internal + if bwd_args.sequence_parallel and ( + bwd_args.parallel_mode == "row" + or (bwd_args.parallel_mode == "column" and bwd_args.fp8) + ): + # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM @@ -1816,14 +1820,10 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): + if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()): fp8_grad = True inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) From 0fd25df007d7b95f02deb899de43babbb284d50b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 15 Jun 2026 13:32:04 +0200 Subject: [PATCH 2/5] [PyTorch] Replace fp8_recipe in LinearBwdArgs with pre-resolved split-accumulator booleans LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths. Replace fp8_recipe with two plain bool fields: - dgrad_use_split_accumulator (default _2X_ACC_DGRAD) - wgrad_use_split_accumulator (default _2X_ACC_WGRAD) These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/linear.py | 31 ++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d8d0d001c5..0593226965 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -118,6 +118,8 @@ class LinearFwdArgs: fp8_output: bool save_original_input: bool backward_override: Optional[str] + dgrad_use_split_accumulator: bool + wgrad_use_split_accumulator: bool custom: bool debug: bool @@ -184,7 +186,8 @@ class LinearBwdArgs: # --- Numerical / dtype config --- activation_dtype: Optional[torch.dtype] = None fp8: bool = False - fp8_recipe: Optional[Recipe] = None + dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD + wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD backward_override: Optional[str] = None is_weight_param_quantized: bool = False custom: bool = False @@ -657,7 +660,8 @@ def _linear_setup_ctx( # Numerical / dtype config bwd_args.activation_dtype = fwd_args.activation_dtype bwd_args.fp8 = fp8 - bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + bwd_args.dgrad_use_split_accumulator = fwd_args.dgrad_use_split_accumulator + bwd_args.wgrad_use_split_accumulator = fwd_args.wgrad_use_split_accumulator bwd_args.backward_override = backward_override bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) bwd_args.custom = fwd_args.custom @@ -959,11 +963,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + use_split_accumulator = bwd_args.dgrad_use_split_accumulator # Update grad input quantizer if grad_input_quantizer is not None: @@ -1134,11 +1134,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_output = grad_output_quantizer(grad_output) # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + use_split_accumulator = bwd_args.wgrad_use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad if bwd_args.is_first_microbatch is not None: @@ -1861,8 +1857,15 @@ def forward( self._fp8_workspaces.get(cache_name) if cache_name is not None else None ) + dgrad_use_split_accumulator = _2X_ACC_DGRAD + wgrad_use_split_accumulator = _2X_ACC_WGRAD if self.fp8: - backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override + _recipe = FP8GlobalStateManager.get_fp8_recipe() + backward_override = _recipe.backward_override + if hasattr(_recipe, "fp8_gemm_dgrad"): + dgrad_use_split_accumulator = _recipe.fp8_gemm_dgrad.use_split_accumulator + if hasattr(_recipe, "fp8_gemm_wgrad"): + wgrad_use_split_accumulator = _recipe.fp8_gemm_wgrad.use_split_accumulator else: backward_override = None custom = is_custom(input_quantizer) or is_custom(weight_quantizer) @@ -1917,6 +1920,8 @@ def forward( fp8_output=fp8_output, save_original_input=self.save_original_input, backward_override=backward_override, + dgrad_use_split_accumulator=dgrad_use_split_accumulator, + wgrad_use_split_accumulator=wgrad_use_split_accumulator, custom=custom, debug=debug, # weight-workspace caching From afe364bb6ef8004a03129b9e39047df6871317a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 14:42:37 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_linear.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6d3b31e818..886a3d39a1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -414,11 +414,7 @@ def forward( if ub_overlap_rs_fprop: # cuBLASMp writes the reduce-scattered output directly into the # GEMM output tensor; Userbuffers writes it into the extra-output buffer. - out = ( - gemm_out - if ub_obj is not None and ub_obj.with_cublasmp() - else reduce_scatter_out - ) + out = gemm_out if ub_obj is not None and ub_obj.with_cublasmp() else reduce_scatter_out elif parallel_mode == "row" and tp_size > 1: nvtx_range_push(f"{nvtx_label}.row_parallel_comm") out = gemm_out @@ -1057,8 +1053,7 @@ def wgrad_gemm( # Gathered input is internal clear_tensor_data(ln_out_total) if ctx.sequence_parallel and ( - ctx.parallel_mode == "row" - or (ctx.parallel_mode == "column" and ctx.fp8) + ctx.parallel_mode == "row" or (ctx.parallel_mode == "column" and ctx.fp8) ): # Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal clear_tensor_data(grad_output) From ee43f5620182b3ccf66728dd3b2ec0396c4513de Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 16 Jun 2026 14:05:02 +0200 Subject: [PATCH 4/5] Reset torch.compile state in destroy_ub to avoid stale assume_constant_result get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 962f168831..3ca3813330 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -568,6 +568,9 @@ def destroy_ub(): _ub_initialized = False global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] + # Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result; + # reset so re-init with different settings doesn't read stale constants. + torch.compiler.reset() def fill_userbuffers_buffer_for_all_gather( From 22f80e40846365bd46bb97a5febb10ca02889136 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 16 Jun 2026 16:32:11 +0200 Subject: [PATCH 5/5] Provide explicit QuantizerRoles in torch.compile custom-recipe test ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 41 ++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 309a5d124e..1286492a6e 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch import ( is_fp8_available, is_mxfp8_available, @@ -123,21 +124,25 @@ def __fx_repr__(self): _Q = get_opaque_type_name(ToyQuantizer) def _make_qfactory(tag: str): - """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" + """Return a qfactory that produces ToyQuantizer instances tagged with *tag*. + + The factory dispatches on ``QuantizerRole.tensor_type``; the roles are + supplied by :meth:`ToyLinear.get_quantizer_roles`. + """ quantizers = { - role: ToyQuantizer(tag=f"{tag}:{role}") - for role in ( - "linear_input", - "linear_weight", - "linear_output", - "linear_grad_output", - "linear_grad_input", + tensor_type: ToyQuantizer(tag=f"{tag}:{tensor_type}") + for tensor_type in ( + "input", + "weight", + "output", + "grad_output", + "grad_input", ) } - def qfactory(role: str): - return quantizers[role] + def qfactory(role: QuantizerRole): + return quantizers[role.tensor_type] return qfactory @@ -163,6 +168,22 @@ def __init__( ) torch.nn.init.normal_(self.weight) + def get_quantizer_roles(self, *, fwd: bool, num_quantizers: int): + # Supplying explicit roles keeps CustomRecipeState from emitting a + # warning (which would graph-break under fullgraph=True) and lets the + # qfactory dispatch per tensor slot. Order must match the module's + # quantizer array (FP8FwdTensorIdx / FP8BwdTensorIdx). + if fwd: + return [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + return [ + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ] + def _get_weight_tensors(self): return [self.weight]