Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/end_to_end/tpu/gemma4/26b/test_gemma4.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/bin/bash

# Validates the Gemma4-26B pre-training pipeline using a pre-converted MaxText checkpoint.

# The flow of this script is as follows:
# 1. Run inference on the pre-converted checkpoint.
# 2. Run pre-training starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the pre-training run.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma4_to_mt.sh $RUN_ID
# bash test_gemma4.sh $RUN_ID


set -ex

run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
MODEL_NAME='gemma4-26b'

# To convert the multimodal model, make sure the use_multimodal is set to be true
USE_MULTIMODAL=false

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items

# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
DATASET_PATH=gs://maxtext-dataset

# Step 1: Run inference on the original checkpoint converted from Hugging Face
if [ ${USE_MULTIMODAL} == true ]; then
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
scan_layers=false use_multimodal=true \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
scan_layers=false prompt='I love to' attention=\'dot_product\'
fi

# Step 2: Run Pre-training on the converted checkpoint
# We can also run training by using the scanned converted checkpoint
# Note that scanned checkpoint helps with efficient training
python3 -m maxtext.trainers.pre_train.train \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \
dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
max_target_length=8192 steps=5 async_checkpointing=false \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL}

# Step 3: Run inference on the checkpoint generated from the previous run
if [ ${USE_MULTIMODAL} == true ]; then
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
scan_layers=false use_multimodal=true \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
scan_layers=false prompt='I love to' attention=\'dot_product\'
fi


57 changes: 57 additions & 0 deletions tests/end_to_end/tpu/gemma4/26b/test_gemma4_rl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash

# Validates the Gemma4-26B RL pipeline using a pre-converted MaxText checkpoint.

# The flow of this script is as follows:
# 1. Run inference on the pre-converted checkpoint.
# 2. Run RL starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the RL run.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma4_to_mt.sh $RUN_ID
# bash test_gemma4_rl.sh $RUN_ID


set -ex

run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
use_pathways=${2:-false}
MODEL_NAME='gemma4-26b'

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items

# Step 1: Run inference on the original checkpoint converted from Hugging Face
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
hbm_utilization_vllm=0.5 \
prompt='Suggest some famous landmarks in London.' \
use_chat_template=True scan_layers=false enable_single_controller=${use_pathways}

# Step 2: Run RL on the converted checkpoint
python3 -m maxtext.trainers.post_train.rl.train_rl \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \
load_parameters_path=${SCANNED_CKPT_PATH} \
run_name=${run_id} rl.loss_algo='grpo' scan_layers=true \
num_batches=5 batch_size=1 num_test_batches=5 \
model_name=${MODEL_NAME} enable_single_controller=${use_pathways} \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
rollout_tensor_parallelism=1 \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
vllm_additional_config='{"maxtext_config": {"model_name": "gemma4-26b", "log_config": "false"}}'


# Step 3: Run inference on the checkpoint generated from the previous run
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
hbm_utilization_vllm=0.5 \
prompt='Suggest some famous landmarks in London.' \
use_chat_template=True scan_layers=true enable_single_controller=${use_pathways}
55 changes: 55 additions & 0 deletions tests/end_to_end/tpu/gemma4/26b/test_gemma4_sft.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/bin/bash

# Validates the Gemma4-26B SFT pipeline using a pre-converted MaxText checkpoint.

# The flow of this script is as follows:
# 1. Run inference on the pre-converted checkpoint.
# 2. Run SFT starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the SFT run.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma4_to_mt.sh $RUN_ID
# bash test_gemma4_sft.sh $RUN_ID


set -ex

run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
use_pathways=${2:-false}
MODEL_NAME='gemma4-26b'

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items

# Step 1: Run inference on the original checkpoint converted from Hugging Face
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
hbm_utilization_vllm=0.5 \
prompt="Suggest some famous landmarks in London." \
use_chat_template=True scan_layers=false enable_single_controller=${use_pathways}

# Step 2: Run SFT on the converted checkpoint
python3 -m maxtext.trainers.post_train.sft.train_sft \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/sft \
load_parameters_path=${SCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
steps=5 scan_layers=true \
model_name=${MODEL_NAME} enable_single_controller=${use_pathways} \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

# Step 3: Run inference on the checkpoint generated from the previous run
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
hbm_utilization_vllm=0.5 \
prompt="Suggest some famous landmarks in London." \
use_chat_template=True scan_layers=true enable_single_controller=${use_pathways}


108 changes: 18 additions & 90 deletions tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh
Original file line number Diff line number Diff line change
@@ -1,103 +1,31 @@
#!/bin/bash

# This script is both an end-to-end test and documentation for converting a
# Gemma4-26B MaxText checkpoint to Hugging Face format. Can be run on a v5p-8.
# Converts a MaxText checkpoint to a Hugging Face model checkpoint.

# The flow of this script is as follows:
# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint.
# 2. Run a forward pass check to compare the logits and KL divergence between
# the converted checkpoint and the original HF model.

# Pre-requisites:
# 1. Set HF_TOKEN environment variable to your Hugging Face access token.
# export HF_TOKEN=<Hugging Face access token>
# 2. Provide a MaxText-format Gemma4-26B checkpoint via CKPT_PATH.
# One can be produced with tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh.
# Usage:
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma4_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='gemma4-26b'
export MODEL_VARIATION='26b-it'
# To convert the multimodal model, set USE_MULTIMODAL=true
USE_MULTIMODAL=false
# Set USE_SCAN_LAYERS=true if the checkpoint was trained with scanned layers
USE_SCAN_LAYERS=true

# Installing torch for deps in forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Non-Googlers: point MODEL_BUCKET to a GCS bucket you own.
export MODEL_BUCKET=gs://maxtext-gemma/gemma4
# Path to a pre-existing MaxText checkpoint for gemma4-26b. Must match USE_SCAN_LAYERS.
# Run tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh to produce one.
export CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/unscanned/0/items
run_id=$1
CKPT_PATH=$2
USE_MULTIMODAL=${3:-false}
SCAN_LAYERS=${4:-false}

# Path to the original HF model weights for logit comparison.
export HF_MODEL=google/gemma-4-26b-a4b-it
MODEL_NAME='gemma4-26b'
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"

export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx}
if [ "${SCAN_LAYERS,,}" = "true" ]; then
scan_status="scanned"
else
scan_status="unscanned"
fi

python3 -m maxtext.checkpoint_conversion.to_huggingface \
"${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
model_name=${MODEL_NAME} \
hf_access_token=${HF_TOKEN} \
tokenizer_type="huggingface" \
load_parameters_path=${CKPT_PATH} \
base_output_directory=${LOCAL_PATH} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=${USE_SCAN_LAYERS}

# Run forward pass logit checker to validate the converted checkpoint.
# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips,
# not for performance purpose.
if [ "${USE_MULTIMODAL}" == true ]; then
TEST_PROMPT='Describe image <|image|>'
TEST_IMAGE='tests/assets/test_image.jpg'
export GOLDEN_LOGITS_PATH=/tmp/golden_gemma4_26b_vision.pickle

python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
--model-id=${HF_MODEL} \
--output-path=${GOLDEN_LOGITS_PATH} \
--prompts="${TEST_PROMPT}" \
--image-paths=${TEST_IMAGE} \
--hf-model-path=${LOCAL_PATH} \
--output-format=pickle

echo "=== Running MaxText Forward Pass Logit Checker ==="
python3 -m tests.utils.forward_pass_logit_checker \
"${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
tokenizer_path=${HF_MODEL} \
load_parameters_path=${CKPT_PATH} \
model_name=${MODEL_NAME} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=${USE_SCAN_LAYERS} \
dtype=float32 \
wi_tile_fwd_embed_dim=512 \
wi_tile_fwd_mlp_dim=512 \
wo_tile_fwd_embed_dim=512 \
wo_tile_fwd_mlp_dim=512 \
matmul_precision=highest \
per_device_batch_size=1 \
attention=dot_product \
prompt="${TEST_PROMPT}" \
image_path=${TEST_IMAGE} \
--max_kl_div=0.1 \
--golden_logits_path=${GOLDEN_LOGITS_PATH}
else
echo "=== Running MaxText Forward Pass Logit Checker ==="
python3 -m tests.utils.forward_pass_logit_checker \
"${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
tokenizer_path=${HF_MODEL} \
load_parameters_path=${CKPT_PATH} \
model_name=${MODEL_NAME} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=${USE_SCAN_LAYERS} \
per_device_batch_size=1 \
dtype=float32 \
wi_tile_fwd_embed_dim=512 \
wi_tile_fwd_mlp_dim=512 \
wo_tile_fwd_embed_dim=512 \
wo_tile_fwd_mlp_dim=512 \
--max_kl_div=0.1 \
--run_hf_model=true \
--hf_model_path=${LOCAL_PATH}
fi
scan_layers=$SCAN_LAYERS
74 changes: 74 additions & 0 deletions tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/bin/bash

# Converts Gemma4-26B HuggingFace checkpoint to MaxText format and validates logit correctness.

# The flow of this script is as follows:
# 1. Install PyTorch (CPU) required for checkpoint conversion.
# 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats.
# 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma4_to_mt.sh $RUN_ID - to convert the checkpoint and run logit check for non-multimodal version
# bash test_gemma4_to_mt.sh $RUN_ID true - to convert the checkpoint and run logit check for multimodal version


set -ex

run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
MODEL_NAME='gemma4-26b'
HF_GOLDEN_MODEL='google/gemma-4-26b-a4b-it'

# To convert the multimodal model, make sure the use_multimodal is set to be true
USE_MULTIMODAL=${2:-false}

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext

# Step 1: Install torch

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please test whether we still need it.

@chiajunglien chiajunglien Jun 26, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the command below to test:

xpk workload create \
  --cluster=mesa-v6e32-eu \
  --workload=gemma4-logit-check-$(date +%m%d%H%M) \
  --device-type=v6e-32 \
  --num-slices=1 \
  --priority=high \
  --project=cienet-cmcs \
  --zone=europe-west4-a \
  --skip-validation \
  --docker-image=gcr.io/tpu-prod-env-multipod/maxtext_jax_nightly:28210168847 \
  --command="set -xue; \
    python3 -m tests.utils.forward_pass_logit_checker \
      load_parameters_path='gs://us-central1-emmalien-test-83df3ecd-bucket/gemma4-26b-0625/unscanned/2026-06-25-06-21-12/0/items' \
      model_name=gemma4-26b \
      use_multimodal=false \
      scan_layers=false \
      --hf_model_path=google/gemma-4-26b-a4b-it \
      --max_kl_div=0.03 \
      --run_hf_model=true \
      hardware=cpu \
      skip_jax_distributed_system=True"

and got the error:

[transformers] PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
[transformers] DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/tests/utils/forward_pass_logit_checker.py", line 50, in <module>
    from maxtext.checkpoint_conversion.utils.hf_utils import convert_jax_weight_to_torch
  File "/deps/src/maxtext/checkpoint_conversion/utils/hf_utils.py", line 25, in <module>
    import torch.nn.functional as F
ModuleNotFoundError: No module named 'torch'
XPK End: Fri Jun 26 09:08:08 UTC 2026
EXIT_CODE=1

Seems it still need torch to be installed.

python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText

# Step 2.a: Convert to unscanned checkpoint (for inference)
python3 -m maxtext.checkpoint_conversion.to_maxtext \
model_name=${MODEL_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=false \
hardware=cpu skip_jax_distributed_system=True \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items
echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}"

# Step 2.b: Convert to scanned checkpoint (for training)
python3 -m maxtext.checkpoint_conversion.to_maxtext \
model_name=${MODEL_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=true \
hardware=cpu skip_jax_distributed_system=True \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items
echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}"

# Step 3: Test whether the forward pass logits match the original HF model
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
# ToDo: improve forward_pass_logit_checker to test multi-modal prompt
if [ "${USE_MULTIMODAL}" = "false" ]; then
python3 -m tests.utils.forward_pass_logit_checker \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
model_name=${MODEL_NAME} \
use_multimodal=${USE_MULTIMODAL} \
per_device_batch_size=1 \
dtype=float32 \
attention=dot_product \
scan_layers=false \
--hf_model_path=${HF_GOLDEN_MODEL} \
--max_kl_div=0.03 \
--run_hf_model=true \
hardware=cpu skip_jax_distributed_system=True
fi
Loading