[torch.compile] Bunch of small changes needed for enabling torch.compile #3130
pggPL wants to merge 5 commits into
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
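The identity-check early return described in the commit message above can be illustrated with a minimal pure-Python sketch. Everything here (`ToyRecipe`, `ToyRecipeState`, `ToyModule`, `init_recipe_state`, `rebuild_count`) is a hypothetical stand-in, not the library's actual classes; only the `recipe_state.recipe is recipe` comparison comes from the commit message.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRecipe:
    """Hypothetical stand-in for a custom recipe object."""
    name: str = "custom"


@dataclass
class ToyRecipeState:
    """Hypothetical stand-in for CustomRecipeState; remembers its recipe."""
    recipe: ToyRecipe
    quantizers: list = field(default_factory=list)


class ToyModule:
    def __init__(self):
        self.recipe_state = None
        self.rebuild_count = 0  # counts how often quantizers are rebuilt

    def init_recipe_state(self, recipe):
        state = self.recipe_state
        # Early exit: the very same recipe object was used last time,
        # so the existing quantizers can be reused.
        if state is not None and state.recipe is recipe:
            return
        self.rebuild_count += 1
        self.recipe_state = ToyRecipeState(recipe, quantizers=["q_input", "q_weight"])


module = ToyModule()
recipe = ToyRecipe()
for _ in range(3):  # three "forward calls" with the same recipe object
    module.init_recipe_state(recipe)
rebuilds_same_recipe = module.rebuild_count

module.init_recipe_state(ToyRecipe())  # equal by value, but a different object
rebuilds_after_new_recipe = module.rebuild_count
```

The sketch uses `is` rather than `==`, matching the commit message: an equal-valued but distinct recipe object still triggers a rebuild, while repeated calls with the same object do not.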
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
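A rough sketch of the refactor described above, under stated assumptions: `LinearBwdArgsSketch` and `resolve_split_accumulator` are hypothetical names, the `_2X_ACC_*` values below are placeholders rather than the library's real defaults, and falling back to the defaults when no recipe is given is also an assumption. Only the two field names and the `fp8_gemm_dgrad` / `fp8_gemm_wgrad` attributes come from the commit message.

```python
from dataclasses import dataclass
from types import SimpleNamespace

# Placeholder defaults (assumed values; the real constants live in the library).
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True


@dataclass(frozen=True)
class LinearBwdArgsSketch:
    """Hypothetical args struct: plain bools instead of a live recipe object."""
    dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD
    wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD


def resolve_split_accumulator(recipe):
    """Resolve the two flags once, at forward/setup time."""
    if recipe is None:
        # Assumption: without a recipe, fall back to the defaults.
        return LinearBwdArgsSketch()
    return LinearBwdArgsSketch(
        dgrad_use_split_accumulator=bool(recipe.fp8_gemm_dgrad.use_split_accumulator),
        wgrad_use_split_accumulator=bool(recipe.fp8_gemm_wgrad.use_split_accumulator),
    )


# Toy recipe exposing only the attributes named in the commit message.
toy_recipe = SimpleNamespace(
    fp8_gemm_dgrad=SimpleNamespace(use_split_accumulator=False),
    fp8_gemm_wgrad=SimpleNamespace(use_split_accumulator=True),
)
bwd_args = resolve_split_accumulator(toy_recipe)
default_args = resolve_split_accumulator(None)
```

After this step the backward only ever sees `bwd_args`, which holds two scalars and no reference to the recipe or anything it owns.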
for more information, see https://pre-commit.ci
Greptile Summary
This PR bundles five targeted changes to make
Confidence Score: 5/5
All five changes are narrowly scoped and semantically equivalent to the code they replace; no regressions were identified.
The split-accumulator refactor resolves the same recipe values at forward time that the old code did, and threads them as plain booleans; the backward behavior is unchanged. The
No files require special attention.
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Fwd as Module.forward
participant UBReg as UB Registry
participant Compiler as torch.compiler
participant BwdArgs as LinearBwdArgs
participant Bwd as _linear_backward
Note over Fwd,Compiler: UB query — compile-friendly path
Fwd->>Compiler: get_ub_is_fp8(name, use_fp8) [assume_constant_result]
Compiler-->>UBReg: get_ub(name, use_fp8).is_fp8_ubuf()
UBReg-->>Fwd: bool (baked as compile-time constant)
Note over Fwd,BwdArgs: Split-accumulator resolution (new)
Fwd->>Fwd: get_fp8_recipe() → resolve dgrad/wgrad split-acc bools
Fwd->>BwdArgs: dgrad_use_split_accumulator, wgrad_use_split_accumulator (plain bool)
Note over Bwd,BwdArgs: Backward — no recipe object needed
Bwd->>BwdArgs: read dgrad_use_split_accumulator
Bwd->>BwdArgs: read wgrad_use_split_accumulator
Note over Bwd: column-SP FP8 cleanup (new)
Bwd->>Bwd: clear_tensor_data(grad_output) if column-SP and fp8
Note over UBReg,Compiler: Re-init path
UBReg->>Compiler: destroy_ub() → torch.compiler.reset()
Note over Compiler: All baked assume_constant_result caches invalidated
Reviews (2): Last reviewed commit: "Reset torch.compile state in destroy_ub ..."
@torch.compiler.assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
    """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant."""
    return get_ub(name, use_fp8).is_fp8_ubuf()
assume_constant_result can become stale after destroy_ub() + re-init
@torch.compiler.assume_constant_result caches the return value per (name, use_fp8) argument pair for the lifetime of a compiled region. If destroy_ub() is called and UB communicators are re-initialized with different FP8 settings (e.g. in a test harness that re-creates the communicators), the cached is_fp8_ubuf() result would be silently stale until the next recompile. In production training this should not happen — UB is typically initialized once — but test suites that tear down and rebuild UB communicators between cases could observe incorrect fp8_output/fp8_grad flags without triggering a recompile.
/te-ci pytorch L1
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
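The staleness raised in the review and the reset fix can be mimicked without torch. The sketch below is only an analogy: `functools.lru_cache` stands in for the value baked by `assume_constant_result`, and `cache_clear()` stands in for `torch.compiler.reset()`. All `toy_*` names and the `reset_cache` flag are hypothetical and do not reflect the real `destroy_ub` signature.

```python
from functools import lru_cache

# Toy registry mapping (name, use_fp8) -> the communicator's is_fp8_ubuf flag.
_toy_ub_registry = {}


def toy_initialize_ub(name, use_fp8, is_fp8_ubuf):
    _toy_ub_registry[(name, use_fp8)] = is_fp8_ubuf


@lru_cache(maxsize=None)
def toy_get_ub_is_fp8(name, use_fp8):
    # The first result per argument pair is "baked" and reused afterwards.
    return _toy_ub_registry[(name, use_fp8)]


def toy_destroy_ub(reset_cache):
    _toy_ub_registry.clear()
    if reset_cache:
        # Analogue of resetting compile state inside destroy_ub.
        toy_get_ub_is_fp8.cache_clear()


toy_initialize_ub("qkv_fprop", True, is_fp8_ubuf=True)
first = toy_get_ub_is_fp8("qkv_fprop", True)

# Tear down and re-init with a different FP8 setting, without a reset.
toy_destroy_ub(reset_cache=False)
toy_initialize_ub("qkv_fprop", True, is_fp8_ubuf=False)
stale = toy_get_ub_is_fp8("qkv_fprop", True)  # still the old baked value

# Same sequence, but with the reset applied on destroy.
toy_destroy_ub(reset_cache=True)
toy_initialize_ub("qkv_fprop", True, is_fp8_ubuf=False)
fresh = toy_get_ub_is_fp8("qkv_fprop", True)  # reflects the new setting
```

As in the commit message, the reset only touches in-memory state; the toy cache here likewise has no on-disk component.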
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Description
Small standalone fixes extracted from a larger torch.compile branch, going directly from main. There are two independent changes: making Userbuffers pybind11 queries compile-friendly, and freeing quantized grad_output early for column-parallel SP. The PR also includes a custom-recipe caching fix, a split-accumulator refactor, and a CI test hook-up.
Type of change
Changes
Checklist: