AMD/ROCm support for qwen3-8b-m2po-full + H100 cross-platform verification #19

Closed
ChangyiYang wants to merge 2 commits into Infini-AI-Lab:main from ChangyiYang:liquid-cross-platform-verify

@ChangyiYang

Brings the examples/math/qwen3-8b-m2po-full recipe up on AMD MI300/MI325
(ROCm / gfx942) and verifies training dynamics match the known-good
NVIDIA H100 run.

Summary

  • ROCm Docker image (docker/Dockerfile.rocm + docker/rocm/*)
    layered on the official SGLang ROCm image, with constraints to keep
    pip from pulling CUDA wheels over the base image's torch/sglang.
  • Code adaptations for ROCm:
    • astraflow/train_worker/platforms/__init__.py: detect ROCm via
      torch.version.hip and return CudaPlatform (ROCm exposes AMD GPUs
      through torch.cuda).
    • astraflow/train_worker/utils/functional/vocab_parallel.py: fall
      back to eager for _gather_logprobs* on ROCm (inductor codegen of
      those reductions crashes on gfx942 / torch 2.9 with a masked
      InductorError). Override via ASTRAFLOW_FORCE_TORCH_COMPILE.
  • Triton-AMD flash-attn for the trainer
    (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE pip install flash-attn==2.8.3 --no-build-isolation). The recipe's default
    attn_impl=flash_attention_2 works as-is — no sdpa override.
  • Launchers:
    • examples/math/qwen3-8b-m2po-full/scripts/run_qwen3-8b-m2po-full_amd.sh
      (docker-run wrapper)
    • examples/math/qwen3-8b-m2po-full/scripts/run_h100_smoke.slurm and
      run_h100_50step.slurm (NVIDIA reference)
    • examples/_common/build_astraflow_rocm.sh for enroot/pyxis clusters.
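
The two code adaptations above can be sketched roughly as follows. This is a minimal illustration, not the code in the PR: the function names, the `CpuPlatform` fallback, and the set of values accepted for `ASTRAFLOW_FORCE_TORCH_COMPILE` are assumptions. The only facts relied on are that `torch.version.hip` is a version string on ROCm builds and `None` otherwise, and that ROCm exposes AMD GPUs through `torch.cuda`. The functions take a `torch.version`-like object so the logic can be exercised without a GPU.

```python
import os
from types import SimpleNamespace


def is_rocm(torch_version) -> bool:
    """True when the torch build reports a HIP version, i.e. a ROCm build."""
    return getattr(torch_version, "hip", None) is not None


def platform_name(torch_version) -> str:
    # ROCm exposes AMD GPUs through torch.cuda, so both vendors map to
    # the same CUDA platform class.
    if is_rocm(torch_version) or getattr(torch_version, "cuda", None) is not None:
        return "CudaPlatform"
    return "CpuPlatform"  # hypothetical fallback name, not from the PR


def use_torch_compile(torch_version, env=None) -> bool:
    """Fall back to eager on ROCm unless the override variable is set."""
    env = os.environ if env is None else env
    # Accepted truthy values are an assumption of this sketch.
    raw = env.get("ASTRAFLOW_FORCE_TORCH_COMPILE", "")
    forced = raw.strip().lower() in {"1", "true", "yes"}
    return forced or not is_rocm(torch_version)


# Stand-ins for torch.version on the two kinds of build.
rocm_build = SimpleNamespace(hip="7.2.0", cuda=None)
cuda_build = SimpleNamespace(hip=None, cuda="12.8")
```

In real code the argument would be `torch.version`; with the stand-ins, `use_torch_compile(rocm_build, {})` is `False` and setting the override variable to `1` flips it back to `True`.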

Why the flash-attn detail matters

AstraFlow's FSDP engine packs multiple sequences into one microbatch and
passes per-sequence boundaries via cu_seq_lens_q/k with
attention_mask=None (fsdp_engine.py:1203). Only transformers'
flash_attention_2 path consumes those cu_seqlens to keep attention
inside each sub-sequence. Under sdpa (the obvious fallback when
flash-attn isn't installed), the kwargs are ignored and a single causal
mask spans the whole packed buffer → packed sub-sequences attend across
boundaries → trainer-recomputed logits are systematically wrong.
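
The difference between the two paths can be illustrated with plain boolean masks. This is only a sketch: it does not call transformers or flash-attn, the function names are hypothetical, and `cu_seqlens` is assumed to be the usual cumulative-length list (for example `[0, 3, 5]` for two packed sequences of lengths 3 and 2).

```python
def causal_mask(n):
    """Single causal mask over the whole packed buffer (the sdpa behaviour)."""
    return [[k <= q for k in range(n)] for q in range(n)]


def packed_causal_mask(cu_seqlens):
    """Causal mask restricted to each sub-sequence (the varlen behaviour)."""
    n = cu_seqlens[-1]
    segment = [0] * n
    for seg_id, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        for t in range(start, end):
            segment[t] = seg_id
    return [
        [k <= q and segment[k] == segment[q] for k in range(n)]
        for q in range(n)
    ]


def leaked_pairs(cu_seqlens):
    """(query, key) positions allowed by the plain mask but not the packed one."""
    n = cu_seqlens[-1]
    plain = causal_mask(n)
    packed = packed_causal_mask(cu_seqlens)
    return [
        (q, k)
        for q in range(n)
        for k in range(n)
        if plain[q][k] and not packed[q][k]
    ]
```

For `cu_seqlens = [0, 3, 5]`, `leaked_pairs` returns six pairs: tokens 3 and 4 (the second sequence) each attend to tokens 0, 1 and 2 (the first sequence). Those are the cross-boundary reads that corrupt the recomputed logits.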

At step 1 (same weights as the rollout, so old_logp ≈ new_logp
should hold):

| metric | AMD broken (sdpa) | AMD fixed (fa2) | H100 ref |
|---|---|---|---|
| importance_weight/avg | 0.41 | 0.9996 | 1.0000 |
| importance_weight/max | up to 1.5e5 | 6.6 | 2.8 |
| approx_kl/avg | −3.0 | −0.0035 | −0.0006 |
| m2po_mean_m2 | ~17 | 0.0083 | 0.0022 |

Even though reward (from the rollout itself) looked superficially fine,
the importance ratios fed into M2PO were garbage and the policy gradient
was effectively broken. On ROCm, flash_attention_2 is therefore a
correctness requirement, not a perf option, whenever the trainer packs
sequences.
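
A step-1 sanity check of this kind could look roughly like the sketch below. It assumes per-token log-probabilities are available as flat lists and that the importance weight is `exp(new_logp - old_logp)`. The exact definitions of `approx_kl` and `m2po_mean_m2` in AstraFlow are not shown in this PR, so the mean log-ratio and mean squared log-ratio are used as illustrative stand-ins. The function names and the 0.05 tolerance are hypothetical.

```python
import math


def step1_ratio_stats(old_logp, new_logp):
    """Summarise how far the trainer's logprobs drift from the rollout's."""
    log_ratio = [n - o for o, n in zip(old_logp, new_logp)]
    iw = [math.exp(d) for d in log_ratio]
    count = len(iw)
    return {
        "iw_avg": sum(iw) / count,
        "iw_max": max(iw),
        # Stand-ins only; not necessarily AstraFlow's approx_kl / m2 formulas.
        "mean_log_ratio": sum(log_ratio) / count,
        "mean_sq_log_ratio": sum(d * d for d in log_ratio) / count,
    }


def looks_healthy(stats, tol=0.05):
    """With identical weights the average importance weight should be near 1."""
    return abs(stats["iw_avg"] - 1.0) <= tol
```

Under this tolerance the reported averages of 0.9996 (AMD fixed) and 1.0000 (H100) pass, while 0.41 (AMD with sdpa) fails.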

Result: AMD now tracks H100

50-step side-by-side, identical config (eval freq 10, recover off):

eval avg@4 (math500 | amc | aime24 | minerva):

| step | AMD-fixed | H100 |
|---|---|---|
| 10 | 85.5 / 51.4 / 26.7 / 40.6 | 85.7 / 51.7 / 29.8 / 40.9 |
| 20 | 87.1 / 55.7 / 32.1 / 39.4 | 86.6 / 54.7 / 32.5 / 41.4 |
| 30 | 87.9 / 60.7 / 36.5 / 42.2 | 86.7 / 55.6 / 28.5 / 41.4 |

Per-step (1–30) means:

| run | reward_mean | seq_len (start → end) | IW_avg |
|---|---|---|---|
| AMD-fixed | 0.577 | 2142 → 1930 | 1.0000 |
| H100 | 0.534 | 2168 → 2152 | 1.0000 |
| AMD-broken (sdpa) | 0.641 | 1364 → 1364 | 0.626 |

Both runs sit on essentially the same eval curve (for AMD-fixed over
steps 10→30: math500 85→88, amc 51→60, aime24 27→36); the broken AMD
run is clearly separable (IW 0.63, length stuck at 1364, eval flat).
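
To make the comparison explicit, the per-benchmark gap between the two runs can be computed directly from the eval table. The sketch below only restates the reported avg@4 numbers; the variable and function names are hypothetical, and no run-to-run variance is available here, so it cannot by itself say what counts as noise.

```python
BENCHMARKS = ("math500", "amc", "aime24", "minerva")

# avg@4 scores copied from the eval table, keyed by training step.
AMD_FIXED = {
    10: (85.5, 51.4, 26.7, 40.6),
    20: (87.1, 55.7, 32.1, 39.4),
    30: (87.9, 60.7, 36.5, 42.2),
}
H100 = {
    10: (85.7, 51.7, 29.8, 40.9),
    20: (86.6, 54.7, 32.5, 41.4),
    30: (86.7, 55.6, 28.5, 41.4),
}


def eval_gaps(step):
    """AMD-fixed minus H100, per benchmark, rounded to one decimal."""
    return {
        name: round(amd - ref, 1)
        for name, amd, ref in zip(BENCHMARKS, AMD_FIXED[step], H100[step])
    }


def max_abs_gap(step):
    """Largest absolute per-benchmark gap at a given step."""
    return max(abs(g) for g in eval_gaps(step).values())
```

Calling `eval_gaps` for steps 10, 20 and 30 gives a quick view of where the runs agree closely and where individual benchmarks, which are small and noisy at avg@4, move apart.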

Full writeup with environment/version matrix and run links:
docs/notes/cross-platform-fix.md.

W&B runs (project liquid-ai/astraflow-math):

  • AMD fixed — qwen3-8b-m2po-full-model0_373f28d2
  • H100 ref — qwen3-8b-m2po-full-model0_b2a99e8f
  • AMD broken (for contrast) — qwen3-8b-m2po-full-model0_9038f5e4

Test plan

  • Single-node 8-GPU AMD MI325 run reproduces qwen3-8b-m2po-full
    training and matches H100 reference within noise (above).
  • Step-1 sanity (old_logp ≈ new_logp, IW ≈ 1) verified on both
    backends.
  • 8×H100 reference run completes the full 50 steps with the
    upstream default flash_attention_2.

🤖 Generated with Claude Code

ChangyiYang and others added 2 commits June 16, 2026 04:25
Port the example to AMD ROCm via the version-matched SGLang ROCm base image
(lmsysorg/sglang:v0.5.12.post1-rocm720-mi30x):

- docker/Dockerfile.rocm: layer astraflow on the ROCm sglang image without
  clobbering the ROCm torch/sglang; install megatron-core/mbridge/torchdata
  --no-deps (bypass megatron's numpy<2 pin) and a Triton-AMD flash-attn
  (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE, no CK compile) for the trainer.
- docker/rocm/{gen_constraints,prune_pyproject,install_tms_shim,build_in_container}.
- platforms/__init__.py: detect ROCm/HIP -> CudaPlatform (was NVIDIA-only).
- vocab_parallel.py: skip torch.compile on ROCm (inductor codegen fails on gfx942).
- examples/.../run_qwen3-8b-m2po-full_amd.sh: docker-run launcher (flash_attention_2,
  Triton-AMD, ckpt off); examples/_common/build_astraflow_rocm.sh: enroot/pyxis build.
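
One way to keep pip from replacing the base image's packages is to pin whatever is already installed in a constraints file and pass it with `pip install -c`. The sketch below shows that idea only; the actual docker/rocm/gen_constraints script is not shown in this PR, and the function name and package list here are hypothetical.

```python
from importlib import metadata


def constraint_lines(packages, version_of=metadata.version):
    """Return 'name==version' pins for the packages that are installed."""
    lines = []
    for name in packages:
        try:
            lines.append(f"{name}=={version_of(name)}")
        except metadata.PackageNotFoundError:
            # Not present in the base image, so there is nothing to protect.
            continue
    return lines


if __name__ == "__main__":
    # Hypothetical list of packages the base image should keep as-is.
    print("\n".join(constraint_lines(["torch", "sglang", "triton"])))
```

Writing the output to a file and installing with `pip install -c <file> ...` makes pip refuse any resolution that would change those versions.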

flash_attention_2 is required (not optional) for correctness: the FSDP engine packs
sequences and passes cu_seqlens, honored only by the flash path; sdpa breaks packed
varlen attention and the recomputed logprobs diverge from the rollout.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Adds Liquid H100 1-node Slurm wrappers (smoke + 50-step) for verifying
the qwen3-8b-m2po-full recipe on NVIDIA, used as cross-platform baseline
for @Astraflow-AMD's ROCm flash-attn fix work.

H100 step-1 metrics with the upstream default flash_attention_2 path:
IW≈1.0000, IW_max=2.80, approx_kl=-0.0006, m2≈0.0022; eval/length grow
through 50 steps. AMD-fixed run matches these step-for-step (see report).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@ChangyiYang

Duplicate of #18 (@Astraflow-AMD opened from Liquid4All/astraflow first with the same branch). Closing to consolidate.