AMD/ROCm support for qwen3-8b-m2po-full + H100 cross-platform verification#19
Closed
ChangyiYang wants to merge 2 commits into
Port the example to AMD ROCm via the version-matched SGLang ROCm base image
(lmsysorg/sglang:v0.5.12.post1-rocm720-mi30x):
- docker/Dockerfile.rocm: layer astraflow on the ROCm sglang image without
clobbering the ROCm torch/sglang; install megatron-core/mbridge/torchdata
--no-deps (bypass megatron's numpy<2 pin) and a Triton-AMD flash-attn
(FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE, no CK compile) for the trainer.
- docker/rocm/{gen_constraints,prune_pyproject,install_tms_shim,build_in_container}.
- platforms/__init__.py: detect ROCm/HIP -> CudaPlatform (was NVIDIA-only).
- vocab_parallel.py: skip torch.compile on ROCm (inductor codegen fails on gfx942).
- examples/.../run_qwen3-8b-m2po-full_amd.sh: docker-run launcher (flash_attention_2,
Triton-AMD, ckpt off); examples/_common/build_astraflow_rocm.sh: enroot/pyxis build.
flash_attention_2 is required (not optional) for correctness: the FSDP engine packs
sequences and passes cu_seqlens, honored only by the flash path; sdpa breaks packed
varlen attention and the recomputed logprobs diverge from the rollout.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
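The ROCm detection and the torch.compile gate described in this commit could look roughly like the sketch below. It is an illustration under stated assumptions, not the code in platforms/__init__.py or vocab_parallel.py: detect_platform and should_torch_compile are hypothetical names, the helper returns a string label instead of a CudaPlatform instance, and the set of values accepted for ASTRAFLOW_FORCE_TORCH_COMPILE is assumed. The one grounded fact is that ROCm builds of PyTorch set torch.version.hip to a version string while CUDA builds leave it as None.

```python
import os
from types import SimpleNamespace


def detect_platform(torch_version) -> str:
    """Return a platform label from a torch.version-like object.

    Hypothetical helper: the PR returns a CudaPlatform instance instead.
    """
    # ROCm builds of PyTorch set torch.version.hip to a version string;
    # CUDA builds leave it as None and set torch.version.cuda instead.
    if getattr(torch_version, "hip", None):
        return "rocm"
    if getattr(torch_version, "cuda", None):
        return "cuda"
    return "cpu"


def should_torch_compile(platform: str, env=None) -> bool:
    """Skip torch.compile on ROCm unless the override variable is set."""
    env = os.environ if env is None else env
    # Assumption: any of these truthy spellings enables the override.
    raw = env.get("ASTRAFLOW_FORCE_TORCH_COMPILE", "")
    forced = raw.lower() in {"1", "true", "yes"}
    return forced or platform != "rocm"


# Stand-ins for torch.version so the sketch runs without PyTorch installed.
rocm_build = SimpleNamespace(hip="7.2.0", cuda=None)
cuda_build = SimpleNamespace(hip=None, cuda="12.8")

print(detect_platform(rocm_build), should_torch_compile("rocm", env={}))
print(detect_platform(cuda_build), should_torch_compile("cuda", env={}))
```

In real code the argument would be torch.version itself. Passing it in keeps the decision testable on a machine with no GPU.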
Adds Liquid H100 1-node Slurm wrappers (smoke + 50-step) for verifying the
qwen3-8b-m2po-full recipe on NVIDIA, used as the cross-platform baseline for
@Astraflow-AMD's ROCm flash-attn fix work.
H100 step-1 metrics with the upstream default flash_attention_2 path:
IW≈1.0000, IW_max=2.80, approx_kl=-0.0006, m2≈0.0022; eval/length grow through
50 steps. The AMD-fixed run matches these step-for-step (see report).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Duplicate of #18 (@Astraflow-AMD opened from Liquid4All/astraflow first with the same branch). Closing to consolidate.
Brings the examples/math/qwen3-8b-m2po-full recipe up on AMD MI300/MI325
(ROCm / gfx942) and verifies that training dynamics match the known-good
NVIDIA H100 run.
Summary
- docker/Dockerfile.rocm + docker/rocm/*: image layered on the official SGLang ROCm image, with constraints to keep pip from pulling CUDA wheels over the base image's torch/sglang.
- astraflow/train_worker/platforms/__init__.py: detect ROCm via torch.version.hip and return CudaPlatform (ROCm exposes AMD GPUs through torch.cuda).
- astraflow/train_worker/utils/functional/vocab_parallel.py: fall back to eager for _gather_logprobs* on ROCm (inductor codegen of those reductions crashes on gfx942 / torch 2.9 with a masked InductorError). Override via ASTRAFLOW_FORCE_TORCH_COMPILE.
- flash-attn with the Triton-AMD backend (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE pip install flash-attn==2.8.3 --no-build-isolation). The recipe's default attn_impl=flash_attention_2 works as-is, with no sdpa override.
- examples/math/qwen3-8b-m2po-full/scripts/run_qwen3-8b-m2po-full_amd.sh (docker-run wrapper)
- examples/math/qwen3-8b-m2po-full/scripts/run_h100_smoke.slurm and run_h100_50step.slurm (NVIDIA reference)
- examples/_common/build_astraflow_rocm.sh for enroot/pyxis clusters.

Why the flash-attn detail matters
AstraFlow's FSDP engine packs multiple sequences into one microbatch and
passes per-sequence boundaries via cu_seq_lens_q/k with attention_mask=None
(fsdp_engine.py:1203). Only transformers' flash_attention_2 path consumes
those cu_seqlens to keep attention inside each sub-sequence. Under sdpa (the
obvious fallback when flash-attn isn't installed), the kwargs are ignored and
a single causal mask spans the whole packed buffer → packed sub-sequences
attend across boundaries → trainer-recomputed logits are systematically wrong.
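The leak can be illustrated with plain boolean masks. The sketch below is not how flash-attn or sdpa are implemented; it only compares which (query, key) pairs each masking rule allows for a toy packed buffer. The helper names and the example boundaries [0, 3, 5] are made up for illustration.

```python
from bisect import bisect_right


def full_causal_mask(total_len):
    """mask[i][j] is True when query i may attend to key j (plain causal)."""
    return [[j <= i for j in range(total_len)] for i in range(total_len)]


def packed_causal_mask(cu_seqlens):
    """Causal mask restricted to each packed sub-sequence.

    cu_seqlens holds cumulative boundaries, e.g. [0, 3, 5] for two
    sequences of lengths 3 and 2 packed into one buffer of 5 tokens.
    """
    total_len = cu_seqlens[-1]
    # Sequence index of each token: boundaries <= position, minus one.
    seq_id = [bisect_right(cu_seqlens, pos) - 1 for pos in range(total_len)]
    return [
        [j <= i and seq_id[i] == seq_id[j] for j in range(total_len)]
        for i in range(total_len)
    ]


cu_seqlens = [0, 3, 5]
total = cu_seqlens[-1]
full = full_causal_mask(total)
packed = packed_causal_mask(cu_seqlens)

# Pairs allowed by the single causal mask but forbidden once the
# sub-sequence boundaries are honored.
leaks = [
    (i, j)
    for i in range(total)
    for j in range(total)
    if full[i][j] and not packed[i][j]
]
print(leaks)
```

Every leaked pair has its query in the second sequence and its key in the first, which is the cross-boundary attention described above.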
At step 1 (same weights as the rollout, so old_logp ≈ new_logp should hold):

[Table: step-1 metrics for the sdpa and fa2 runs. Rows: importance_weight/avg, importance_weight/max, approx_kl/avg, m2po_mean_m2. Column labels are truncated and cell values are not recoverable from this copy.]

Even though reward (from the rollout itself) looked superficially fine,
the importance ratios fed into M2PO were garbage and the policy gradient
was effectively broken. On ROCm, flash_attention_2 is therefore a
correctness requirement, not a perf option, whenever the trainer packs
sequences.
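A minimal sketch of the step-1 sanity check follows. It assumes the importance weight is exp(new_logp - old_logp) per token, that approx_kl is the simple mean(old_logp - new_logp) estimator (which can go slightly negative), and that m2 is the mean squared log-ratio. The exact definitions AstraFlow uses are not shown in this PR, so treat the formulas and the step_stats helper as assumptions. The log-probabilities are toy values, not measurements from either run.

```python
import math


def step_stats(old_logp, new_logp):
    """Summary statistics of the per-token importance ratio."""
    log_ratio = [n - o for o, n in zip(old_logp, new_logp)]
    ratios = [math.exp(d) for d in log_ratio]
    count = len(log_ratio)
    return {
        "importance_weight/avg": sum(ratios) / count,
        "importance_weight/max": max(ratios),
        # Assumed estimator: mean(old_logp - new_logp).
        "approx_kl/avg": sum(-d for d in log_ratio) / count,
        # Assumed definition: second moment of the log-ratio.
        "m2": sum(d * d for d in log_ratio) / count,
    }


rollout_logp = [-1.0, -2.0, -0.5]

# Trainer recomputes the same log-probs: ratio is exactly 1 everywhere.
matched = step_stats(rollout_logp, [-1.0, -2.0, -0.5])

# Trainer recomputes different log-probs for the same weights: the
# ratio drifts away from 1 even though no update has happened yet.
mismatched = step_stats(rollout_logp, [-1.5, -2.0, -1.5])

print(matched)
print(mismatched)
```

With identical weights, any departure of importance_weight/avg from 1 at step 1 points at the trainer's forward pass, not at the policy.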
Result: AMD now tracks H100
50-step side-by-side, identical config (eval freq 10, recover off):
eval avg@4 (math500 | amc | aime24 | minerva):
Per-step (1–30) means:
Both runs sit on essentially the same eval curve (math500 85→88,
amc 51→60, aime24 27→36); the broken AMD run is clearly separable
(IW 0.63, length stuck at 1364, eval flat).
Full writeup with environment/version matrix and run links:
docs/notes/cross-platform-fix.md.

W&B runs (project liquid-ai/astraflow-math):
- qwen3-8b-m2po-full-model0_373f28d2
- qwen3-8b-m2po-full-model0_b2a99e8f
- qwen3-8b-m2po-full-model0_9038f5e4

Test plan
- […] training and matches H100 reference within noise (above).
- […] old_logp ≈ new_logp, IW ≈ 1) verified on both backends.
- […] upstream default flash_attention_2.

🤖 Generated with Claude Code