Add AMD ROCm (MI300/MI325) support for examples/math/qwen3-8b-m2po-full #18
Open
ChangyiYang wants to merge 3 commits into
Port the example to AMD ROCm via the version-matched SGLang ROCm base image
(lmsysorg/sglang:v0.5.12.post1-rocm720-mi30x):
- docker/Dockerfile.rocm: layer astraflow on the ROCm sglang image without
clobbering the ROCm torch/sglang; install megatron-core/mbridge/torchdata
--no-deps (bypass megatron's numpy<2 pin) and a Triton-AMD flash-attn
(FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE, no CK compile) for the trainer.
- docker/rocm/{gen_constraints,prune_pyproject,install_tms_shim,build_in_container}.
- platforms/__init__.py: detect ROCm/HIP -> CudaPlatform (was NVIDIA-only).
- vocab_parallel.py: skip torch.compile on ROCm (inductor codegen fails on gfx942).
- examples/.../run_qwen3-8b-m2po-full_amd.sh: docker-run launcher (flash_attention_2,
Triton-AMD, ckpt off); examples/_common/build_astraflow_rocm.sh: enroot/pyxis build.
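The platform-detection and torch.compile items above can be sketched roughly as follows. This is a minimal illustration, not the code in platforms/__init__.py or vocab_parallel.py: the function names `detect_platform`, `should_compile`, and `maybe_compile` are hypothetical, the string platform labels stand in for the real platform classes, and the accepted values of ASTRAFLOW_FORCE_TORCH_COMPILE are an assumption. The one fact relied on is that a ROCm build of PyTorch sets `torch.version.hip` to a version string, while a CUDA build leaves it as `None`.

```python
import os
from typing import Callable, Mapping, Optional


def detect_platform(cuda_version: Optional[str], hip_version: Optional[str]) -> str:
    """Classify the PyTorch build from its version fields.

    In real code the arguments would be torch.version.cuda and
    torch.version.hip; they are passed in here so the sketch runs
    without torch installed.
    """
    if hip_version is not None:
        return "rocm"  # the PR routes this case to CudaPlatform
    if cuda_version is not None:
        return "cuda"
    return "unknown"


def should_compile(platform: str, env: Optional[Mapping[str, str]] = None) -> bool:
    """Skip compilation on ROCm unless the override variable is set.

    Assumption: the override is treated as truthy for "1", "true", "yes".
    """
    env = os.environ if env is None else env
    forced = env.get("ASTRAFLOW_FORCE_TORCH_COMPILE", "0").lower() in ("1", "true", "yes")
    return forced or platform != "rocm"


def maybe_compile(fn: Callable, platform: str, compile_fn: Callable[[Callable], Callable],
                  env: Optional[Mapping[str, str]] = None) -> Callable:
    """Return compile_fn(fn) when compilation is allowed, else fn unchanged (eager)."""
    return compile_fn(fn) if should_compile(platform, env) else fn
```

In a trainer, `compile_fn` would be `torch.compile`; passing it as an argument keeps the gating logic testable on machines without a GPU.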
flash_attention_2 is required (not optional) for correctness: the FSDP engine packs
sequences and passes cu_seqlens, honored only by the flash path; sdpa breaks packed
varlen attention and the recomputed logprobs diverge from the rollout.
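The failure mode described above can be reproduced on a toy example. The sketch below is pure Python with scalar "embeddings" and a single head; it is not the FSDP engine's code, and `causal_attention` and `seq_ids_from_cu_seqlens` are hypothetical helpers. It compares attention that respects the cu_seqlens boundaries against attention that applies only a causal mask over the packed row.

```python
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def seq_ids_from_cu_seqlens(cu_seqlens):
    """Expand cumulative lengths, e.g. [0, 3, 5], into per-token ids [0, 0, 0, 1, 1]."""
    ids = []
    for n, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        ids.extend([n] * (end - start))
    return ids


def causal_attention(q, k, v, seq_id=None):
    """Causal attention over scalar tokens.

    With seq_id given, token i attends only to j <= i in the same
    sequence (the varlen behavior); without it, token i attends to
    every earlier token in the packed row.
    """
    out = []
    for i in range(len(q)):
        allowed = [j for j in range(i + 1)
                   if seq_id is None or seq_id[j] == seq_id[i]]
        weights = softmax([q[i] * k[j] for j in allowed])
        out.append(sum(w * v[j] for w, j in zip(weights, allowed)))
    return out


# Two sequences (lengths 3 and 2) packed into one row of 5 tokens.
x = [0.5, -1.0, 2.0, 1.5, -0.5]
cu_seqlens = [0, 3, 5]
ids = seq_ids_from_cu_seqlens(cu_seqlens)

# Reference: run each sequence on its own, then concatenate.
separate = causal_attention(x[:3], x[:3], x[:3]) + causal_attention(x[3:], x[3:], x[3:])
respecting = causal_attention(x, x, x, seq_id=ids)  # boundaries honored
ignoring = causal_attention(x, x, x)                # boundaries ignored

leak = max(abs(a - b) for a, b in zip(ignoring, separate))
```

Here `respecting` matches `separate`, while `ignoring` differs on the tokens of the second sequence because they also attend to the first. Any logprob computed from the leaked outputs would no longer match what the rollout engine produced for each sequence in isolation.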
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Adds Liquid H100 1-node Slurm wrappers (smoke + 50-step) for verifying the qwen3-8b-m2po-full recipe on NVIDIA, used as cross-platform baseline for @Astraflow-AMD's ROCm flash-attn fix work. H100 step-1 metrics with the upstream default flash_attention_2 path: IW≈1.0000, IW_max=2.80, approx_kl=-0.0006, m2≈0.0022; eval/length grow through 50 steps. AMD-fixed run matches these step-for-step (see report). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Root cause + diagnosis (sdpa breaks packed varlen attention -> logprob/IW divergence), the Triton-AMD flash-attn fix, and the AMD-vs-H100 50-step comparison (eval/seq_len/importance-weight). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Brings examples/math/qwen3-8b-m2po-full up on AMD MI300/MI325 (ROCm / gfx942) and verifies its training dynamics match an NVIDIA H100 run of the same recipe.

Changes
- docker/Dockerfile.rocm (+ docker/rocm/* helpers): FROM the version-matched lmsysorg/sglang:v0.5.12.post1-rocm720-mi30x; install AstraFlow's deps under a constraints file that protects the base image's ROCm torch/sglang; megatron-core/mbridge/torchdata via --no-deps (megatron-core's numpy<2 pin conflicts with the base's numpy 2.x but is fine at runtime); and a Triton-AMD flash-attn (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE, no CK compile) for the trainer.
- train_worker/platforms/__init__.py: recognize ROCm/HIP → CudaPlatform (was NVIDIA-name-only → UnknownPlatform).
- train_worker/utils/functional/vocab_parallel.py: fall back to eager on ROCm, where inductor codegen of the logprob reductions fails on gfx942 (override via ASTRAFLOW_FORCE_TORCH_COMPILE).
- examples/math/qwen3-8b-m2po-full/scripts/run_qwen3-8b-m2po-full_amd.sh (docker), examples/_common/build_astraflow_rocm.sh (enroot/pyxis), and H100 reference sbatch scripts.
- docs/notes/cross-platform-fix.md.

Why flash_attention_2 is required on ROCm (not just a perf choice)
The FSDP engine packs sequences and passes cu_seq_lens_q/k (fsdp_engine.py:1203) with attention_mask=None; only the flash_attention_2 path honors those boundaries. With an sdpa fallback, packed sub-sequences attend across boundaries, so the trainer's recomputed logprobs diverge from the rollout: importance weights explode and the M2PO policy gradient is corrupted (while task reward still looks plausible).

[Table: sdpa (broken) vs. flash_attention_2 (fixed); only the column headers survive.]

After the fix, the AMD eval curve tracks H100 (math500 85→88, amc 51→60 over 30 steps) and response length grows normally (~2000, was stuck at 1364).
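The importance-weight symptom can be checked with a small diagnostic along these lines. This is a hedged sketch: the function name `importance_diagnostics` is hypothetical, and the formulas for approx_kl (the simple mean of rollout minus trainer logprobs) and m2 (the mean squared log-ratio) are assumptions about what the reported metrics measure, not definitions taken from this repository.

```python
import math


def importance_diagnostics(rollout_logprobs, trainer_logprobs):
    """Per-token comparison of rollout logprobs and trainer-recomputed logprobs.

    Assumed definitions (not confirmed against the repo):
      iw        = exp(trainer - rollout)
      approx_kl = mean(rollout - trainer)
      m2        = mean((trainer - rollout) ** 2)
    """
    log_ratio = [t - r for t, r in zip(trainer_logprobs, rollout_logprobs)]
    iw = [math.exp(d) for d in log_ratio]
    n = len(log_ratio)
    return {
        "iw_mean": sum(iw) / n,
        "iw_max": max(iw),
        "approx_kl": -sum(log_ratio) / n,
        "m2": sum(d * d for d in log_ratio) / n,
    }


# Illustrative values only, not measurements from either platform.
rollout = [-0.7, -1.2, -0.3, -2.1]
matched = importance_diagnostics(rollout, rollout)
drifted = importance_diagnostics(rollout, [-0.7, -1.2, 1.7, -2.1])
```

When the trainer reproduces the rollout exactly, `iw_mean` and `iw_max` are 1.0 and the other two values are 0. A single token whose recomputed logprob drifts by 2.0 already pushes `iw_max` to about 7.4, which is the kind of blow-up the sdpa path produced.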
🤖 Generated with Claude Code