Add AMD ROCm (MI300/MI325) support for examples/math/qwen3-8b-m2po-full#18

Open
ChangyiYang wants to merge 3 commits into Infini-AI-Lab:main from Liquid4All:liquid-cross-platform-verify

@ChangyiYang

Brings examples/math/qwen3-8b-m2po-full up on AMD MI300/MI325 (ROCm / gfx942) and verifies its training dynamics match an NVIDIA H100 run of the same recipe.

Changes

  • ROCm Docker image docker/Dockerfile.rocm (+ docker/rocm/* helpers): builds FROM the version-matched lmsysorg/sglang:v0.5.12.post1-rocm720-mi30x; installs AstraFlow's deps under a constraints file that protects the base image's ROCm torch/sglang; installs megatron-core/mbridge/torchdata via --no-deps (megatron-core's numpy<2 pin conflicts with the base's numpy 2.x but is fine at runtime); and adds a Triton-AMD flash-attn (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE, no CK compile) for the trainer.
  • Platform detection (train_worker/platforms/__init__.py): recognize ROCm/HIP → CudaPlatform (was NVIDIA-name-only → UnknownPlatform).
  • torch.compile (train_worker/utils/functional/vocab_parallel.py): fall back to eager on ROCm where inductor codegen of the logprob reductions fails on gfx942 (override via ASTRAFLOW_FORCE_TORCH_COMPILE).
  • Launchers: examples/math/qwen3-8b-m2po-full/scripts/run_qwen3-8b-m2po-full_amd.sh (docker), examples/_common/build_astraflow_rocm.sh (enroot/pyxis), and H100 reference sbatch scripts.
  • Report: docs/notes/cross-platform-fix.md.
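The platform-detection and torch.compile items above can be sketched roughly as follows. This is a minimal illustration, not the code in train_worker/platforms/__init__.py or vocab_parallel.py: the function names are hypothetical, the platform labels are plain strings standing in for the real classes, and the accepted values of ASTRAFLOW_FORCE_TORCH_COMPILE are an assumption. The one real API relied on is `torch.version.hip`, which is a version string on ROCm builds of PyTorch and `None` otherwise.

```python
import os


def detect_hip_version():
    """Return the HIP version string on a ROCm build of PyTorch, else None."""
    try:
        import torch
    except ImportError:
        return None
    return getattr(torch.version, "hip", None)


def classify_platform(device_name, hip_version):
    """Hypothetical stand-in for the platform lookup described in the PR."""
    # ROCm/HIP builds expose the CUDA-style torch API, so treat them like CUDA.
    if hip_version is not None:
        return "CudaPlatform"
    # Name-only matching (the earlier behavior as described) misses AMD devices.
    if "nvidia" in device_name.lower():
        return "CudaPlatform"
    return "UnknownPlatform"


def should_use_torch_compile(hip_version, env=None):
    """Compile by default, stay eager on ROCm unless the override is set."""
    env = os.environ if env is None else env
    # Assumption: any of these values enables the override.
    flag = env.get("ASTRAFLOW_FORCE_TORCH_COMPILE", "").lower()
    forced = flag in {"1", "true", "yes"}
    return forced or hip_version is None
```

A caller would pass `detect_hip_version()` into both helpers once at startup, so the decision is made in one place and is easy to override for debugging.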

Why flash_attention_2 is required on ROCm (not just a perf choice)

The FSDP engine packs sequences and passes cu_seq_lens_q/k (fsdp_engine.py:1203) with attention_mask=None; only the flash_attention_2 path honors those boundaries. With an sdpa fallback, packed sub-sequences attend across boundaries, so the trainer's recomputed logprobs diverge from the rollout. Importance weights explode and the M2PO policy gradient is corrupted, even though task reward still looks plausible:

| step-1 metric | sdpa (broken) | flash_attention_2 (fixed) | H100 ref |
| --- | --- | --- | --- |
| importance_weight avg / max | 0.41 / ~1.5e5 | 0.9996 / 6.6 | 1.0000 / 2.8 |
| m2po_mean_m2 | ~17 | 0.0083 | 0.0022 |
| old_logp vs new_logp | -0.30 vs -3.2 | -0.219 vs -0.222 | ≈ equal |
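
For context on these metrics, the sketch below computes per-token importance weights as exp(new_logp - old_logp), the standard importance-sampling ratio between the trainer's recomputed logprobs and the rollout's. The second-moment statistic is an assumption about what m2po_mean_m2 measures (here, the mean squared log ratio); the function and key names are hypothetical and not AstraFlow's.

```python
import math


def importance_stats(old_logps, new_logps):
    """Summarize how far recomputed logprobs drift from rollout logprobs."""
    # Per-token log ratio: zero when trainer and rollout agree exactly.
    log_ratios = [new - old for old, new in zip(old_logps, new_logps)]
    weights = [math.exp(d) for d in log_ratios]
    return {
        "iw_avg": sum(weights) / len(weights),
        "iw_max": max(weights),
        # Assumed definition: second moment of the log ratio.
        "m2": sum(d * d for d in log_ratios) / len(log_ratios),
    }


# Matching logprobs give weights of exactly 1 and a zero second moment.
healthy = importance_stats([-0.2, -0.5], [-0.2, -0.5])
# A large gap between old and new logprobs pushes weights far from 1.
diverged = importance_stats([-0.30], [-3.2])
```

When the two logprob sets agree, every weight is 1 and the second moment is 0, which is the regime the flash_attention_2 and H100 columns sit close to.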

After the fix, the AMD eval curve tracks H100 (math500 85→88, amc 51→60 over 30 steps) and response length grows normally (~2000, was stuck at 1364).
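
To illustrate why the attention backend matters for packed sequences, here is a minimal pure-Python sketch. It assumes cumulative sequence lengths in the style of the cu_seq_lens_q/k values mentioned above; the helper names are hypothetical and not taken from fsdp_engine.py. It compares the query/key pairs a plain causal mask permits with those permitted once sequence boundaries are honored.

```python
def plain_causal_mask(total_len):
    """Causal mask over the whole packed buffer, ignoring sequence boundaries."""
    return [[k <= q for k in range(total_len)] for q in range(total_len)]


def packed_causal_mask(cu_seqlens):
    """Causal mask that only allows attention within each packed sequence."""
    total_len = cu_seqlens[-1]
    # Map every token position to the index of the sequence it belongs to.
    segment = [0] * total_len
    bounds = zip(cu_seqlens[:-1], cu_seqlens[1:])
    for seq_id, (start, end) in enumerate(bounds):
        for pos in range(start, end):
            segment[pos] = seq_id
    return [
        [segment[q] == segment[k] and k <= q for k in range(total_len)]
        for q in range(total_len)
    ]


def count_leaked_pairs(cu_seqlens):
    """Count (query, key) pairs that cross a sequence boundary."""
    plain = plain_causal_mask(cu_seqlens[-1])
    packed = packed_causal_mask(cu_seqlens)
    return sum(
        1
        for q, row in enumerate(plain)
        for k, allowed in enumerate(row)
        if allowed and not packed[q][k]
    )


# Two sequences of lengths 3 and 2 packed into one buffer of 5 tokens.
leaked = count_leaked_pairs([0, 3, 5])
```

In this toy case the two tokens of the second sequence can each see all three tokens of the first under the plain mask, so any backend that ignores the boundaries computes attention over context the rollout never had.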

🤖 Generated with Claude Code

ChangyiYang and others added 3 commits June 16, 2026 04:25
Port the example to AMD ROCm via the version-matched SGLang ROCm base image
(lmsysorg/sglang:v0.5.12.post1-rocm720-mi30x):

- docker/Dockerfile.rocm: layer astraflow on the ROCm sglang image without
  clobbering the ROCm torch/sglang; install megatron-core/mbridge/torchdata
  --no-deps (bypass megatron's numpy<2 pin) and a Triton-AMD flash-attn
  (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE, no CK compile) for the trainer.
- docker/rocm/{gen_constraints,prune_pyproject,install_tms_shim,build_in_container}.
- platforms/__init__.py: detect ROCm/HIP -> CudaPlatform (was NVIDIA-only).
- vocab_parallel.py: skip torch.compile on ROCm (inductor codegen fails on gfx942).
- examples/.../run_qwen3-8b-m2po-full_amd.sh: docker-run launcher (flash_attention_2,
  Triton-AMD, ckpt off); examples/_common/build_astraflow_rocm.sh: enroot/pyxis build.

flash_attention_2 is required (not optional) for correctness: the FSDP engine packs
sequences and passes cu_seqlens, honored only by the flash path; sdpa breaks packed
varlen attention and the recomputed logprobs diverge from the rollout.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Adds Liquid H100 1-node Slurm wrappers (smoke + 50-step) for verifying
the qwen3-8b-m2po-full recipe on NVIDIA, used as cross-platform baseline
for @Astraflow-AMD's ROCm flash-attn fix work.

H100 step-1 metrics with the upstream default flash_attention_2 path:
IW≈1.0000, IW_max=2.80, approx_kl=-0.0006, m2≈0.0022; eval/length grow
through 50 steps. AMD-fixed run matches these step-for-step (see report).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Root cause + diagnosis (sdpa breaks packed varlen attention -> logprob/IW
divergence), the Triton-AMD flash-attn fix, and the AMD-vs-H100 50-step
comparison (eval/seq_len/importance-weight).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>