
Support specifying tokamax gmm tile sizes in MaxText#3779

Open
darisoy wants to merge 1 commit into main from darisoy-gmm-tile

Conversation


darisoy (Collaborator) commented Apr 29, 2026

Description

This PR enables specifying the tile sizes for both the forward and backward passes of Tokamax GMM (ragged_dot) in MaxText, and updates the fallback behavior when autotuning is enabled.

Proposed Changes

  • New Autotune Flag (src/maxtext/configs/types.py): Exposes a new tokamax_gmm_autotune config flag to toggle Tokamax's auto-tuner for GMM.
  • Simplified Routing & Reused Tiling (src/maxtext/layers/moe.py & src/maxtext/models/deepseek_batchsplit_fp8.py): Instead of introducing new Tokamax-specific tiling flags, this PR reuses the existing tiling tuple parameter (configured via wi_tiling and wo_tiling flags). If autotuning is disabled (tokamax_gmm_autotune=False), it passes the existing tiling configurations directly via a custom PallasMosaicTpuRaggedDot implementation for the Forward (gmm_tiling), Backward DLHS (gmm_rhs_transpose_tiling), and Backward DRHS (tgmm_tiling) passes.
  • Heuristic Autotuning Fallback: When autotuning is enabled (tokamax_gmm_autotune=True), the GMM implementation now uses the heuristics fallback for autotuning cache misses.
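The routing behavior described in the bullets above can be sketched roughly as follows. This is an illustrative sketch only: the helper name `select_gmm_impl` and the dictionary return shape are hypothetical, not the actual MaxText or Tokamax API.

```python
# Illustrative sketch of the tiling/autotune routing described above.
# select_gmm_impl and its return shape are hypothetical; the real change
# wires the tiling tuples into a custom PallasMosaicTpuRaggedDot.

def select_gmm_impl(tokamax_gmm_autotune, tiling):
    """Choose tiling behavior for the Tokamax GMM (ragged_dot) call.

    tiling: (fwd, dlhs, drhs) tile-size tuples, e.g. from wi_tiling/wo_tiling.
    """
    if tokamax_gmm_autotune:
        # Autotune on: let Tokamax pick tile sizes, falling back to its
        # built-in heuristics on autotuning cache misses.
        return {"mode": "autotune", "fallback": "heuristics"}
    # Autotune off: pass the existing tiling tuples through directly,
    # one per pass.
    fwd, dlhs, drhs = tiling
    return {
        "mode": "explicit",
        "gmm_tiling": fwd,                  # forward pass
        "gmm_rhs_transpose_tiling": dlhs,   # backward d(LHS)
        "tgmm_tiling": drhs,                # backward d(RHS)
    }
```

Reusing the existing `wi_tiling`/`wo_tiling` tuples avoids introducing a parallel set of Tokamax-specific flags for the same tile sizes.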

FIXES: b/506157856

Tests

Verified by manual testing on a tpu7x-8 VM using the deepseek decoder block.

Custom Tiling Sizes Test

Running with tokamax_gmm_autotune=False to verify that the existing base tiling configurations are correctly passed to the Tokamax GMM implementation:

JAX_PLATFORMS=tpu,cpu PYTHONPATH=src python3 -m maxtext.trainers.pre_train.train \
    ... \
    use_tokamax_gmm=True \
    tokamax_gmm_autotune=False \
    wi_tile_fwd_batch_seq=128 \
    wi_tile_fwd_embed_dim=128 \
    wi_tile_fwd_mlp_dim=128 \
    wi_tile_dlhs_batch_seq=256 \
    wi_tile_dlhs_embed_dim=256 \
    wi_tile_dlhs_mlp_dim=256 \
    wi_tile_drhs_batch_seq=512 \
    wi_tile_drhs_embed_dim=512 \
    wi_tile_drhs_mlp_dim=512
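For readability, the per-pass flags in the command above group into one (batch_seq, embed_dim, mlp_dim) tile-size tuple per pass. The grouping below is illustrative only; the exact flag-to-tuple wiring inside MaxText may differ.

```python
# Illustrative grouping of the wi_tile_* flags from the test command
# into per-pass tiling tuples (batch_seq, embed_dim, mlp_dim).
wi_tiling = {
    "fwd":  (128, 128, 128),  # wi_tile_fwd_{batch_seq,embed_dim,mlp_dim}
    "dlhs": (256, 256, 256),  # wi_tile_dlhs_* (backward d(LHS))
    "drhs": (512, 512, 512),  # wi_tile_drhs_* (backward d(RHS))
}
```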

See details in http://b/506157856

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.
