QLoRA Support + NNX Decoder Sharding Fixes#3968
Conversation
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
| """Test applying standard LoRA to model with scan_layers=True.""" | ||
| self._run_apply_lora_test(scan_layers=True) | ||
|
|
||
| @unittest.skip("Awaiting qwix fix for QLoRA params materialization") |
There was a problem hiding this comment.
Why are we skipping QLoRA tests? Does that mean QLoRA won't work even after merging this PR?
There was a problem hiding this comment.
Yes, we need a new qwix release for NNX model QLoRA support.
My two CLs addressing the underlying parameter materialization issues have already been merged upstream in qwix. I will sync with the team to ask for their next release plan. Once the new release comes out, we just need to update the dependency version here and unskip these tests.
Should we wait and update the dependency/unskip tests in this PR, or would you prefer to merge this PR as-is and do the update in a separate follow-up PR once the release is published?
There was a problem hiding this comment.
I would say let's wait for Qwix fix, update dependency in MAxText and then merge this PR.
There was a problem hiding this comment.
Will do. I'll leave this PR open and update the qwix dependency as soon as the new release is published.
c37faef to
4dde975
Compare
4dde975 to
6a12059
Compare
6e1a1e3 to
4d44567
Compare
4d44567 to
67eec07
Compare
Description
This PR introduces core support for QLoRA and implements robust sharding metadata synchronization for NNX decoders.
Currently, applying LoRA adapters alongside quantization and complex multi-host sharding setups can lead to PartitionSpec mismatches and cross-backend device_put issues during lax.scan.
This PR solves these issues by:
Future improvements will include removing the _safe_reshard workaround once the underlying JAX/Qwix parameter materialization issues are fully resolved upstream.
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.