Skip to content

[torch.compile] Bunch of small changes needed for enabling torch.compile#3130

Open
pggPL wants to merge 5 commits into
NVIDIA:mainfrom
pggPL:torch_compile_small_fixes
Open

[torch.compile] Bunch of small changes needed for enabling torch.compile#3130
pggPL wants to merge 5 commits into
NVIDIA:mainfrom
pggPL:torch_compile_small_fixes

Conversation

@pggPL

@pggPL pggPL commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Description

Small standalone fixes extracted from a larger torch.compile branch, going directly from main. Two independent changes: making Userbuffers pybind11 queries compile-friendly, and freeing quantized grad_output early for column-parallel SP. Plus a custom-recipe caching fix, a split-accumulator refactor, and a CI test hook-up.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  1. Userbuffers pybind11 queries under torch.compile
  • is_fp8_ubuf() / with_cublasmp() are compile-time constants but graph-break when traced. At the nn.Module.forward boundary (where no UB communicator object is in hand yet) they go through get_ub_is_fp8(name, use_fp8), wrapped in torch.compiler.assume_constant_result — only plain (str, bool) args are baked, so guards are well-defined and don't rely on pybind-object identity.
  • In the hot forward/backward implementation paths the UB communicator is already fetched, so those call ub_obj.is_fp8_ubuf() / ub_obj.with_cublasmp() directly — no wrapper, no string concatenation, no redundant registry lookup. Eager speed is preserved.
  1. Free quantized grad_output early for column-parallel SP
  • Row-parallel SP already called clear_tensor_data(grad_output) on the gathered tensor. Column-parallel SP quantizes grad_output to a Float8TensorStorage (an internal tensor) but never freed it. Under torch.compile reduce-overhead this left live pool tensors at recording end ("Detected N tensor(s) in the cudagraph pool not tracked as outputs"). The free now covers row-SP and column-SP-FP8 (column-SP non-FP8 is a no-op view, so it's excluded).
  1. Replace fp8_recipe in LinearBwdArgs with pre-resolved split-accumulator booleans
  • LinearBwdArgs no longer carries the recipe object (which holds process-group references and is compile-unfriendly). dgrad_use_split_accumulator / wgrad_use_split_accumulator are resolved once in Linear.forward (reusing the existing get_fp8_recipe() call) and threaded through as plain booleans.
  1. Custom-recipe quantizer caching fix
  • CustomRecipeState early-exit was missing an identity check, so quantizers were rebuilt on every forward even when the recipe was unchanged. Added if recipe_state.recipe is recipe: return.
  1. Test hook-up
  • Added test_torch_compile.py to L0_pytorch_unittest.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 2 commits June 15, 2026 16:40
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants
so torch.compile(fullgraph=True) can trace through them without graph
breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py;
  `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py,
base.py, gemm.py, userbuffers_backward_linear.py and
userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP
backward. Row-parallel SP already called clear_tensor_data(grad_output)
to release the gathered tensor; column-parallel SP quantizes grad_output
to Float8TensorStorage but never freed it before returning.  Under
torch.compile reduce-overhead this leaves 3 live pool tensors at
recording end and triggers "Detected 3 tensor(s) in the cudagraph pool
not tracked as outputs".  Extend the existing clear_tensor_data guard to
cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward
call even when the recipe object has not changed. The existing early-exit
for CustomRecipeState was missing an identity check on the recipe object,
so any repeated call with the same recipe would bypass the early-return
and rebuild quantizers unnecessarily.  Add `if recipe_state.recipe is
recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and
existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could
extract fp8_gemm_dgrad.use_split_accumulator and
fp8_gemm_wgrad.use_split_accumulator at GEMM time.  Recipe objects hold
process-group references and are not serialisable as compile-time
constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args
struct, so the backward consumes scalars instead of a live recipe object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL requested a review from ksivaman as a code owner June 15, 2026 14:41
@greptile-apps

greptile-apps Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR bundles five targeted changes to make transformer_engine compatible with torch.compile: wrapping UB communicator property lookups in @torch.compiler.assume_constant_result, freeing quantized grad_output early in the column-parallel SP FP8 backward, removing the fp8_recipe object from LinearBwdArgs in favour of pre-resolved split-accumulator booleans, fixing a missing identity check in CustomRecipeState early-exit, and wiring a new test file into the L0 CI suite.

  • UB compile-friendly queriesget_ub_is_fp8(name, use_fp8) is added as a wrapper decorated with @torch.compiler.assume_constant_result; forward-boundary call sites in Linear, LayerNormLinear, and LayerNormMLP are updated to it. destroy_ub() calls torch.compiler.reset() to prevent stale baked constants on re-init.
  • Split-accumulator refactorLinearBwdArgs.fp8_recipe is replaced with two plain bool fields (dgrad_use_split_accumulator, wgrad_use_split_accumulator) resolved at forward time, removing a process-group–carrying object from the backward graph.
  • clear_tensor_data / CustomRecipeState fixes – both linear.py and layernorm_linear.py now also free the quantized grad_output in the column-SP FP8 path; CustomRecipeState early-exit now checks recipe_state.recipe is recipe to avoid rebuilding quantizers on every forward when the recipe is unchanged.

Confidence Score: 5/5

All five changes are narrowly scoped and semantically equivalent to the code they replace; no regressions were identified.

The split-accumulator refactor resolves the same recipe values at forward time that the old code did, and threads them as plain booleans — the backward behavior is unchanged. The clear_tensor_data extension correctly identifies the column-SP FP8 path as the only one where a new internal tensor is created (non-FP8 column-SP is a reshape/contiguous view and is deliberately excluded). The CustomRecipeState identity fix is a clear correctness improvement that prevents unnecessary quantizer rebuilds. The torch.compiler.reset() call in destroy_ub() is the right API to invalidate baked assume_constant_result caches on re-init. LinearFwdArgs is constructed in exactly one place, so the two new required fields cannot be missed.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/base.py Adds get_ub_is_fp8 wrapped with assume_constant_result; adds torch.compiler.reset() to destroy_ub(); fixes CustomRecipeState early-exit with identity check. All three changes are correct and well-scoped.
transformer_engine/pytorch/module/linear.py Replaces fp8_recipe in LinearBwdArgs with pre-resolved split-accumulator booleans; adds clear_tensor_data(grad_output) for column-SP FP8 path; switches forward-boundary UB queries to get_ub_is_fp8. Logic is equivalent to the removed recipe-object path and correctness is preserved.
transformer_engine/pytorch/module/layernorm_linear.py Extends clear_tensor_data(grad_output) to column-SP FP8 path; switches UB queries to get_ub_is_fp8. Note: ctx.fp8_recipe is still saved for split-accumulator use in the backward — not addressed in this PR.
transformer_engine/pytorch/module/layernorm_mlp.py Single-line change: UB query in LayerNormMLP.forward switched from get_ub(...).is_fp8_ubuf() to get_ub_is_fp8(...). No other changes in this file.
qa/L0_pytorch_unittest/test.sh Adds test_torch_compile.py to the L0 PyTorch unit-test CI suite. Straightforward addition following existing patterns.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Fwd as Module.forward
    participant UBReg as UB Registry
    participant Compiler as torch.compiler
    participant BwdArgs as LinearBwdArgs
    participant Bwd as _linear_backward

    Note over Fwd,Compiler: UB query — compile-friendly path
    Fwd->>Compiler: get_ub_is_fp8(name, use_fp8) [assume_constant_result]
    Compiler-->>UBReg: get_ub(name, use_fp8).is_fp8_ubuf()
    UBReg-->>Fwd: bool (baked as compile-time constant)

    Note over Fwd,BwdArgs: Split-accumulator resolution (new)
    Fwd->>Fwd: get_fp8_recipe() → resolve dgrad/wgrad split-acc bools
    Fwd->>BwdArgs: dgrad_use_split_accumulator, wgrad_use_split_accumulator (plain bool)

    Note over Bwd,BwdArgs: Backward — no recipe object needed
    Bwd->>BwdArgs: read dgrad_use_split_accumulator
    Bwd->>BwdArgs: read wgrad_use_split_accumulator

    Note over Bwd: column-SP FP8 cleanup (new)
    Bwd->>Bwd: clear_tensor_data(grad_output) if column-SP and fp8

    Note over UBReg,Compiler: Re-init path
    UBReg->>Compiler: destroy_ub() → torch.compiler.reset()
    Note over Compiler: All baked assume_constant_result caches invalidated
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Fwd as Module.forward
    participant UBReg as UB Registry
    participant Compiler as torch.compiler
    participant BwdArgs as LinearBwdArgs
    participant Bwd as _linear_backward

    Note over Fwd,Compiler: UB query — compile-friendly path
    Fwd->>Compiler: get_ub_is_fp8(name, use_fp8) [assume_constant_result]
    Compiler-->>UBReg: get_ub(name, use_fp8).is_fp8_ubuf()
    UBReg-->>Fwd: bool (baked as compile-time constant)

    Note over Fwd,BwdArgs: Split-accumulator resolution (new)
    Fwd->>Fwd: get_fp8_recipe() → resolve dgrad/wgrad split-acc bools
    Fwd->>BwdArgs: dgrad_use_split_accumulator, wgrad_use_split_accumulator (plain bool)

    Note over Bwd,BwdArgs: Backward — no recipe object needed
    Bwd->>BwdArgs: read dgrad_use_split_accumulator
    Bwd->>BwdArgs: read wgrad_use_split_accumulator

    Note over Bwd: column-SP FP8 cleanup (new)
    Bwd->>Bwd: clear_tensor_data(grad_output) if column-SP and fp8

    Note over UBReg,Compiler: Re-init path
    UBReg->>Compiler: destroy_ub() → torch.compiler.reset()
    Note over Compiler: All baked assume_constant_result caches invalidated
Loading

Reviews (2): Last reviewed commit: "Reset torch.compile state in destroy_ub ..." | Re-trigger Greptile

Comment on lines +557 to +560
@torch.compiler.assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
"""Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant."""
return get_ub(name, use_fp8).is_fp8_ubuf()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 assume_constant_result can become stale after destroy_ub() + re-init

@torch.compiler.assume_constant_result caches the return value per (name, use_fp8) argument pair for the lifetime of a compiled region. If destroy_ub() is called and UB communicators are re-initialized with different FP8 settings (e.g. in a test harness that re-creates the communicators), the cached is_fp8_ubuf() result would be silently stale until the next recompile. In production training this should not happen — UB is typically initialized once — but test suites that tear down and rebuild UB communicators between cases could observe incorrect fp8_output/fp8_grad flags without triggering a recompile.

@pggPL

pggPL commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

pggPL added 2 commits June 16, 2026 14:05
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a
reset, destroy_ub + re-init with different FP8 settings would read stale
values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit
the no-roles warning, which graph-breaks under fullgraph=True. qfactory
dispatches on role.tensor_type instead of a pre-baked string key.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant