[PyTorch] Add method for mcore to register wgrad accumulation hook #2886

Open
ksivaman wants to merge 3 commits into NVIDIA:main from ksivaman:fix_delay_wgrad_mcore_integration

Conversation

@ksivaman
Member

Description

This PR implements wgrad accumulation hook registration for sequential ops; the equivalent mechanism is already implemented in modules.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add method for mcore to register wgrad accumulation hook.
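To illustrate the new method from the caller's side, here is a minimal sketch (the hook body and names below are hypothetical; Megatron-Core's real hook would launch the gradient reduce-scatter and/or set param.grad to None):

```python
from typing import Callable


def make_wgrad_hook(param_name: str, fired: list) -> Callable:
    # Hypothetical hook: in Megatron-Core this would trigger the
    # gradient reduce-scatter and/or clear param.grad after wgrad
    # computation completes.
    def hook() -> None:
        fired.append(param_name)
    return hook


fired: list = []
# Stands in for op.wgrad_accumulation_and_reduce_hooks on GroupedLinear.
hooks: list[Callable] = []
hooks.append(make_wgrad_hook("fc1.weight", fired))

# What _trigger_wgrad_accumulation_and_reduce_hooks would do at the
# end of backward_dw.
for h in hooks:
    h()
```

After the loop, `fired` contains `["fc1.weight"]`, showing the hook ran once wgrad work finished.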

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Gao Deng <gdeng@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 15, 2026 18:08
@greptile-apps
Contributor

greptile-apps bot commented Apr 15, 2026

Greptile Summary

This PR adds register_wgrad_accumulation_and_reduce_hooks to the GroupedLinear sequential op, mirroring the same mechanism already present in TransformerEngineBaseModule. Hooks are stored in a new wgrad_accumulation_and_reduce_hooks list and fired at the end of backward_dw (both the _accumulate_into_main_grad early-return path and the regular gradient-assignment path).

Confidence Score: 5/5

PR is safe to merge; the implementation correctly mirrors the base module's hook mechanism.

Both code paths in backward_dw that complete wgrad computation now trigger the hooks (accumulate_into_main_grad path and regular path). The untyped list annotation is a P2 style issue already raised in a prior thread. No logic bugs or data-integrity concerns were identified.

No files require special attention.
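Based on the summary above, the hook mechanism can be sketched as follows (a simplified, hypothetical stand-in for GroupedLinear; the deferred wgrad GEMM and gradient assignment are omitted):

```python
from typing import Callable


class GroupedLinearSketch:
    """Hypothetical, simplified stand-in for the GroupedLinear sequential op."""

    def __init__(self) -> None:
        # Hooks fired once wgrad computation completes.
        self.wgrad_accumulation_and_reduce_hooks: list[Callable] = []
        self._accumulate_into_main_grad = False

    def register_wgrad_accumulation_and_reduce_hooks(self, hook: Callable) -> None:
        self.wgrad_accumulation_and_reduce_hooks.append(hook)

    def _trigger_wgrad_accumulation_and_reduce_hooks(self) -> None:
        for hook in self.wgrad_accumulation_and_reduce_hooks:
            hook()

    def backward_dw(self) -> None:
        # ... deferred wgrad GEMM would run here (WeightGradStore.pop()) ...
        if self._accumulate_into_main_grad:
            # Early-return path: grads were accumulated into main_grad.
            self._trigger_wgrad_accumulation_and_reduce_hooks()
            return
        # Regular path: .grad tensors would be assigned here.
        self._trigger_wgrad_accumulation_and_reduce_hooks()
```

Both exits of backward_dw fire the hooks, matching the two paths described in the summary.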

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/basic/grouped_linear.py Adds wgrad accumulation hook registration and triggering to GroupedLinear; both execution paths in backward_dw correctly call _trigger_wgrad_accumulation_and_reduce_hooks.

Sequence Diagram

sequenceDiagram
    participant mcore as Megatron-Core (DDP)
    participant GL as GroupedLinear
    participant WGS as WeightGradStore

    mcore->>GL: register_wgrad_accumulation_and_reduce_hooks(hook)
    note over GL: appended to wgrad_accumulation_and_reduce_hooks[]

    note over GL: Forward pass
    GL->>WGS: put(tensor_list, func)

    note over GL: Main backward pass
    GL-->>GL: skip_backward_post_hook (bypass AccumulateGrad)

    note over mcore: After backward
    mcore->>GL: backward_dw()
    GL->>WGS: pop()  [executes wgrad GEMM]
    alt _accumulate_into_main_grad
        GL->>GL: _trigger_wgrad_accumulation_and_reduce_hooks()
    else assign .grad tensors
        GL->>GL: _trigger_wgrad_accumulation_and_reduce_hooks()
    end
    GL->>mcore: hook() [reduce-scatter / param.grad=None]
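The WeightGradStore put/pop deferral shown in the diagram can be illustrated with a hypothetical stand-in (the real store lives in Transformer Engine; this only sketches the deferral pattern):

```python
from collections import deque
from typing import Callable


class WeightGradStoreSketch:
    """Hypothetical queue that records wgrad work during the main backward
    pass and replays it later from backward_dw()."""

    def __init__(self) -> None:
        self._queue: deque = deque()

    def put(self, tensor_list: list, func: Callable) -> None:
        # Called during backward: defer the wgrad GEMM instead of running it.
        self._queue.append((tensor_list, func))

    def pop(self) -> None:
        # Called from backward_dw(): execute the deferred computation.
        tensor_list, func = self._queue.popleft()
        func(*tensor_list)
```

The key property is that nothing runs at put() time, which is what lets mcore schedule the wgrad GEMM after the main backward pass.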

Reviews (2): Last reviewed commit: "Merge branch 'main' into fix_delay_wgrad..."

@@ -114,6 +114,7 @@ def __init__(
self.num_extra_inputs = 2
Contributor


P2 Untyped list annotation

The annotation list does not specify the element type, making static analysis and IDE tooling less helpful. Since the base class (TransformerEngineBaseModule) stores the same kind of callbacks, aligning on a typed annotation would be cleaner.

Suggested change
self.num_extra_inputs = 2
self.wgrad_accumulation_and_reduce_hooks: list[Callable] = []


@ksivaman
Member Author

/te-ci pytorch L0

Collaborator

@timmoon10 timmoon10 left a comment


LGTM
