From f903e5f3f7d635337b79ba05522dc09f7677b7d7 Mon Sep 17 00:00:00 2001 From: Emma Lien Date: Fri, 26 Jun 2026 06:14:14 +0000 Subject: [PATCH 1/2] feat(testing): add end-to-end training and inference tests for llama3.1-70b Introduces E2E test configurations and scripts for llama3.1-70b model, covering both inference decoding and pre-training validation pipelines. --- .../tpu/llama3.1/70b/test_llama3.1_70b.sh | 84 +++++++++++++++++++ .../llama3.1/70b/test_llama3.1_70b_to_hf.sh | 31 +++++++ .../llama3.1/70b/test_llama3.1_70b_to_mt.sh | 75 +++++++++++++++++ 3 files changed, 190 insertions(+) create mode 100644 tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b.sh create mode 100644 tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_hf.sh create mode 100644 tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh diff --git a/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b.sh b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b.sh new file mode 100644 index 0000000000..b6b64ecf54 --- /dev/null +++ b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Validates the Llama3.1-70B pre-training pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run pre-training starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the pre-training run. 
+ +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_llama3.1_70b_to_mt.sh $RUN_ID +# bash test_llama3.1_70b.sh $RUN_ID + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} +MODEL_NAME='llama3.1-70b' + +# To convert the multimodal model, make sure the use_multimodal is set to be true +USE_MULTIMODAL=false + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items + +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +DATASET_PATH=gs://maxtext-dataset + +# Step 1: Run inference on the original checkpoint converted from Hugging Face +if [ ${USE_MULTIMODAL} == true ]; then + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ + scan_layers=false use_multimodal=true \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +else + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + scan_layers=false prompt='I love to' attention=\'dot_product\' +fi + +# Step 2: Run Pre-training on the converted checkpoint +# We can also run training by using the scanned converted checkpoint +# Note that scanned checkpoint helps with 
efficient training +python3 -m maxtext.trainers.pre_train.train \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \ + dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_target_length=8192 steps=5 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL} + +# Step 3: Run inference on the checkpoint generated from the previous run +if [ ${USE_MULTIMODAL} == true ]; then + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ + scan_layers=false use_multimodal=true \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +else + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + scan_layers=false prompt='I love to' attention=\'dot_product\' +fi + + diff --git a/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_hf.sh b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_hf.sh new file mode 100644 index 0000000000..1f115bb24a --- /dev/null +++ b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_hf.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Converts a MaxText checkpoint to a Hugging Face model checkpoint. 
+ +# Usage: +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_llama3.1_70b_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS + +set -ex + +run_id=$1 +CKPT_PATH=$2 +USE_MULTIMODAL=${3:-false} +SCAN_LAYERS=${4:-false} + +MODEL_NAME='llama3.1-70b' +BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}" + +if [ "${SCAN_LAYERS,,}" = "true" ]; then + scan_status="scanned" +else + scan_status="unscanned" +fi + +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name=${MODEL_NAME} \ + tokenizer_type="huggingface" \ + load_parameters_path=${CKPT_PATH} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=$SCAN_LAYERS diff --git a/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh new file mode 100644 index 0000000000..aa521aa57e --- /dev/null +++ b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# Converts Llama3.1-70B HuggingFace checkpoint to MaxText format and validates logit correctness. + +# The flow of this script is as follows: +# 1. Install PyTorch (CPU) required for checkpoint conversion. +# 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats. +# 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model. 
+ +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_llama3.1_70b_to_mt.sh $RUN_ID - to convert the checkpoint and run logit check for non-multimodal version +# bash test_llama3.1_70b_to_mt.sh $RUN_ID true - to convert the checkpoint and run logit check for multimodal version + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} +MODEL_NAME='llama3.1-70b' +HF_GOLDEN_MODEL='meta-llama/Llama-3.1-70B' + +# To convert the multimodal model, make sure the use_multimodal is set to be true +USE_MULTIMODAL=${2:-false} + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText + +# Step 2.a: Convert to unscanned checkpoint (for inference) +python3 -m maxtext.checkpoint_conversion.to_maxtext \ + model_name=${MODEL_NAME} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=false \ + hardware=cpu skip_jax_distributed_system=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + --lazy_load_tensors=False \ + --eager_load_method='transformers' + +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items +echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}" + +# Step 2.b: Convert to scanned checkpoint (for training) +python3 -m maxtext.checkpoint_conversion.to_maxtext \ + model_name=${MODEL_NAME} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=true \ + hardware=cpu skip_jax_distributed_system=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + --lazy_load_tensors=False \ + 
--eager_load_method='transformers' + +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items +echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}" + +# Step 3: Test whether the forward pass logits match the original HF model +# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` +# ToDo: improve forward_pass_logit_checker to test multi-modal prompt +if [ "${USE_MULTIMODAL}" = "false" ]; then + python3 -m tests.utils.forward_pass_logit_checker \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + model_name=${MODEL_NAME} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=false \ + --hf_model_path=${HF_GOLDEN_MODEL} \ + --max_kl_div=0.03 \ + --run_hf_model=true \ + hardware=cpu skip_jax_distributed_system=True +fi From df41cbd8501c406d7a659a6f7870b5c2dc72727a Mon Sep 17 00:00:00 2001 From: Emma Lien Date: Fri, 26 Jun 2026 08:49:08 +0000 Subject: [PATCH 2/2] fix --- .../tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh index aa521aa57e..09cd605ce5 100644 --- a/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh +++ b/tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh @@ -38,9 +38,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ use_multimodal=${USE_MULTIMODAL} \ scan_layers=false \ hardware=cpu skip_jax_distributed_system=True \ - checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ - --lazy_load_tensors=False \ - --eager_load_method='transformers' + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}" @@ -52,9 +50,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ use_multimodal=${USE_MULTIMODAL} \ scan_layers=true \ hardware=cpu 
skip_jax_distributed_system=True \ - checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ - --lazy_load_tensors=False \ - --eager_load_method='transformers' + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}"