[PyTorch] Add method for mcore to register wgrad accumulation hook#2886
ksivaman wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Gao Deng <gdeng@nvidia.com>
Greptile Summary

This PR adds a method that lets Megatron-Core register wgrad accumulation and reduce hooks on TE modules.

Confidence Score: 5/5. PR is safe to merge; the implementation correctly mirrors the base module's hook mechanism. Both code paths in `backward_dw` that complete wgrad computation now trigger the hooks (the `accumulate_into_main_grad` path and the regular path). The untyped hook list is a minor style point raised in review. No files require special attention.
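The summary above notes that both completion paths in `backward_dw` now fire the hooks. A minimal sketch of that control flow, with all names and the elided wgrad GEMM purely illustrative (the real method lives in TE's PyTorch modules and operates on CUDA tensors):

```python
class Param:
    """Toy parameter with the two gradient slots the summary refers to."""

    def __init__(self):
        self.grad = None       # regular path: freshly assigned gradient
        self.main_grad = 0.0   # accumulate_into_main_grad path: FP32 buffer


def backward_dw(params, wgrads, accumulate_into_main_grad, hooks):
    # The deferred wgrad GEMM would run here (elided in this sketch).
    if accumulate_into_main_grad:
        # Path 1: accumulate directly into the main_grad buffers.
        for p, g in zip(params, wgrads):
            p.main_grad += g
    else:
        # Path 2: assign the computed gradients to .grad.
        for p, g in zip(params, wgrads):
            p.grad = g
    # The fix described above: hooks fire on BOTH paths.
    for hook in hooks:
        hook()


fired = []
ps = [Param(), Param()]
backward_dw(ps, [1.0, 2.0], accumulate_into_main_grad=True,
            hooks=[lambda: fired.append("hook")])
# Hooks fire exactly once per registered callback, regardless of path.
```
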
Sequence Diagram

```mermaid
sequenceDiagram
    participant mcore as Megatron-Core (DDP)
    participant GL as GroupedLinear
    participant WGS as WeightGradStore
    mcore->>GL: register_wgrad_accumulation_and_reduce_hooks(hook)
    note over GL: appended to wgrad_accumulation_and_reduce_hooks[]
    note over GL: Forward pass
    GL->>WGS: put(tensor_list, func)
    note over GL: Main backward pass
    GL-->>GL: skip_backward_post_hook (bypass AccumulateGrad)
    note over mcore: After backward
    mcore->>GL: backward_dw()
    GL->>WGS: pop() [executes wgrad GEMM]
    alt _accumulate_into_main_grad
        GL->>GL: _trigger_wgrad_accumulation_and_reduce_hooks()
    else assign .grad tensors
        GL->>GL: _trigger_wgrad_accumulation_and_reduce_hooks()
    end
    GL->>mcore: hook() [reduce-scatter / param.grad=None]
```
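The registration/trigger plumbing in the diagram can be sketched in plain Python. Method names mirror the diagram; the class body is a toy stand-in (no GEMM, no WeightGradStore), not the real `GroupedLinear`:

```python
from typing import Callable


class GroupedLinearSketch:
    """Toy stand-in for GroupedLinear showing only the hook plumbing."""

    def __init__(self) -> None:
        self.wgrad_accumulation_and_reduce_hooks: list[Callable] = []
        self.wgrad_computed = False

    def register_wgrad_accumulation_and_reduce_hooks(self, hook: Callable) -> None:
        # mcore (DDP) calls this before training to append its callback.
        self.wgrad_accumulation_and_reduce_hooks.append(hook)

    def _trigger_wgrad_accumulation_and_reduce_hooks(self) -> None:
        for hook in self.wgrad_accumulation_and_reduce_hooks:
            hook()

    def backward_dw(self) -> None:
        # Real code pops queued tensors and runs the wgrad GEMM here;
        # this sketch only records that wgrad completed, then fires hooks.
        self.wgrad_computed = True
        self._trigger_wgrad_accumulation_and_reduce_hooks()


calls = []
gl = GroupedLinearSketch()
gl.register_wgrad_accumulation_and_reduce_hooks(lambda: calls.append("reduce-scatter"))
gl.backward_dw()
```

In the real flow, the callback would launch the gradient reduce-scatter (or clear `param.grad`), which is why it must run only after the delayed wgrad GEMM has completed.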
Reviews (2): Last reviewed commit: "Merge branch 'main' into fix_delay_wgrad..."
```diff
@@ -114,6 +114,7 @@ def __init__(
     self.num_extra_inputs = 2
```
The `list` annotation does not specify its element type, making static analysis and IDE tooling less helpful. Since the base class (`TransformerEngineBaseModule`) stores the same kind of callbacks, aligning on a typed annotation would be cleaner.
Suggested change:

```python
self.num_extra_inputs = 2
self.wgrad_accumulation_and_reduce_hooks: list[Callable] = []
```
/te-ci pytorch L0
Description
This PR adds a method for mcore to register wgrad accumulation hooks on sequential ops; the same mechanism is already implemented in modules.
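A hedged sketch of how the mcore side might use the new method: walk the TE modules, register one reduction callback each, then drain the delayed wgrad work after the main backward pass. The module class and hook bodies below are illustrative stand-ins, not the real Megatron-Core or TE API:

```python
events = []


def make_reduce_hook(param_name):
    def hook():
        # In real mcore this would launch a reduce-scatter on main_grad
        # and/or set param.grad = None; here we just record call order.
        events.append(f"reduce {param_name}")
    return hook


class FakeTEModule:
    """Illustrative stand-in for a TE module exposing the new method."""

    def __init__(self):
        self.wgrad_accumulation_and_reduce_hooks = []

    def register_wgrad_accumulation_and_reduce_hooks(self, hook):
        self.wgrad_accumulation_and_reduce_hooks.append(hook)

    def backward_dw(self):
        # Delayed wgrad computation would happen here, then hooks fire.
        for hook in self.wgrad_accumulation_and_reduce_hooks:
            hook()


modules = {"fc1": FakeTEModule(), "fc2": FakeTEModule()}
for name, module in modules.items():
    module.register_wgrad_accumulation_and_reduce_hooks(make_reduce_hook(name))

# After the main backward pass, mcore drains the delayed wgrad work:
for module in modules.values():
    module.backward_dw()
```

This keeps gradient reduction ordered after the deferred wgrad GEMMs, which is the point of routing the hooks through `backward_dw` rather than through autograd's `AccumulateGrad`.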
Type of change
Changes
Checklist: