feat(zero): enable torch.func transforms on engine for ZeRO 0/1/2#8026
Open
roycho96 wants to merge 5 commits into
Open
feat(zero): enable torch.func transforms on engine for ZeRO 0/1/2#8026roycho96 wants to merge 5 commits into
roycho96 wants to merge 5 commits into
Conversation
torch.func.grad / grad_and_value / jacrev invoke autograd through torch.autograd.grad, which fires the engine's output-tensor hooks but intentionally bypasses engine.backward(). The prologue then raises on ZeRO-0 (the safety net for direct loss.backward() callers) and the epilogue indexes empty ZeRO-1/2 grad bucket bookkeeping that the transformed graph never populated. Parameters are not leaves under the transform, so per-param post-accumulate-grad hooks never fire. Detect the active functorch interpreter via torch._C._functorch.peek_interpreter_stack and short-circuit both hooks early. The existing safety net for non-functorch direct loss.backward() callers (deepspeedai#7665) is preserved. Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
…0/1/2 Compare each transform's output to a non-DeepSpeed baseline cloned from the same initialization so a future regression that silently zeros gradients fails the test. Includes a negative case that locks in the ZeRO-0 direct-loss.backward() safety net. Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4a4bd2ad5a
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Follow-up to #7916 and #8023.
Makes
torch.func.grad / grad_and_value / jacrevandvmap(grad)work when called directly on a DeepSpeed engine for ZeRO 0/1/2.torch.func.grad(lambda x: engine(x))(x)torch.func.grad_and_value(lambda x: engine(x))(x)torch.func.jacrev(lambda x: engine(x))(x)torch.func.vmap(torch.func.grad(...))(x_batch)torch.func.vmap(lambda x: engine(x))(x_batch)engine.backward(loss)(regression)vmapalone runs only the forward graph so it never hit the broken backward hooks and already worked before this PR; included in the table for completeness.Usage:
ZeRO-3 hits a separate SIGSEGV from the same APIs and is tracked separately.
Test:
pytest tests/unit/v1/zero/test_zero_torch_func.py