fix(pt): enable second-order autograd for tabulate descriptors#5537
fix(pt): enable second-order autograd for tabulate descriptors#5537njzjz wants to merge 1 commit into
Conversation
📝 WalkthroughWalkthroughAdds second-order autograd support to all five PyTorch ChangesSecond-order autograd for all FusionSe* ops
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
source/op/pt/tabulate_multi_device.cc (1)
607-608: ⚡ Quick winRemove unused private member variable.
The
devicemember is declared but never used anywhere inTabulateFusionSeAGradOp. This appears to be leftover from an earlier implementation.class TabulateFusionSeAGradOp : public torch::autograd::Function<TabulateFusionSeAGradOp> { - private: - std::string device; - public:🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/op/pt/tabulate_multi_device.cc` around lines 607 - 608, Remove the unused private member variable `device` from the `TabulateFusionSeAGradOp` class since it is declared but never referenced anywhere in the implementation. Simply delete the line containing `std::string device;` from the private section.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@source/op/pt/tabulate_multi_device.cc`:
- Around line 607-608: Remove the unused private member variable `device` from
the `TabulateFusionSeAGradOp` class since it is declared but never referenced
anywhere in the implementation. Simply delete the line containing `std::string
device;` from the private section.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 10e5cae3-359e-4d0a-b565-f7fc90453f06
📒 Files selected for processing (6)
source/op/pt/tabulate_multi_device.ccsource/tests/pt/test_tabulate_fusion_se_a.pysource/tests/pt/test_tabulate_fusion_se_atten.pysource/tests/pt/test_tabulate_fusion_se_r.pysource/tests/pt/test_tabulate_fusion_se_t.pysource/tests/pt/test_tabulate_fusion_se_t_tebd.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5537 +/- ##
==========================================
+ Coverage 82.18% 82.20% +0.02%
==========================================
Files 890 890
Lines 101358 101616 +258
Branches 4240 4266 +26
==========================================
+ Hits 83301 83534 +233
- Misses 16756 16760 +4
- Partials 1301 1322 +21 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
se_a,se_atten,se_t,se_r, andse_t_tebdbackward paths to existing grad-grad kernelsFixes #4994.
Tests
pytest source/tests/pt/test_tabulate_fusion_se_a.py source/tests/pt/test_tabulate_fusion_se_atten.py source/tests/pt/test_tabulate_fusion_se_r.py source/tests/pt/test_tabulate_fusion_se_t.py source/tests/pt/test_tabulate_fusion_se_t_tebd.py -qruff check .ruff format .clang-formatSummary by CodeRabbit
New Features
Tests