Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/bin/bash

# Validates the Gemma3-4B pre-training pipeline using a pre-converted MaxText checkpoint.

# The flow of this script is as follows:
# 1. Run inference on the pre-converted checkpoint.
# 2. Run pre-training starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the pre-training run.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma3_to_mt.sh $RUN_ID
# bash test_gemma3.sh $RUN_ID


set -ex

run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
MODEL_NAME='llama3.1-70b'

# To convert the multimodal model, make sure the use_multimodal is set to be true
USE_MULTIMODAL=false

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items

# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
DATASET_PATH=gs://maxtext-dataset

# Step 1: Run inference on the original checkpoint converted from Hugging Face
if [ ${USE_MULTIMODAL} == true ]; then
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
scan_layers=false use_multimodal=true \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
scan_layers=false prompt='I love to' attention=\'dot_product\'
fi

# Step 2: Run Pre-training on the converted checkpoint
# We can also run training by using the scanned converted checkpoint
# Note that scanned checkpoint helps with efficient training
python3 -m maxtext.trainers.pre_train.train \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \
dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
per_device_batch_size=1 run_name=${run_id} \
max_target_length=8192 steps=5 async_checkpointing=false \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL}

# Step 3: Run inference on the checkpoint generated from the previous run
if [ ${USE_MULTIMODAL} == true ]; then
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
scan_layers=false use_multimodal=true \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
per_device_batch_size=1 run_name=${run_id} \
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
scan_layers=false prompt='I love to' attention=\'dot_product\'
fi


31 changes: 31 additions & 0 deletions tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_hf.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/bin/bash

# Converts a MaxText checkpoint to a Hugging Face model checkpoint.

# Usage:
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma3_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS

set -ex

run_id=$1
CKPT_PATH=$2
USE_MULTIMODAL=${3:-false}
SCAN_LAYERS=${4:-false}

MODEL_NAME='llama3.1-70b'
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"

if [ "${SCAN_LAYERS,,}" = "true" ]; then
scan_status="scanned"
else
scan_status="unscanned"
fi

python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL_NAME} \
tokenizer_type="huggingface" \
load_parameters_path=${CKPT_PATH} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=$SCAN_LAYERS
71 changes: 71 additions & 0 deletions tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/bin/bash

# Converts Gemma3-4B HuggingFace checkpoint to MaxText format and validates logit correctness.

# The flow of this script is as follows:
# 1. Install PyTorch (CPU) required for checkpoint conversion.
# 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats.
# 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma3_to_mt.sh $RUN_ID - to convert the checkpoint and run logit check for non-multimodal version
# bash test_gemma3_to_mt.sh $RUN_ID true - to convert the checkpoint and run logit check for multimodal version


set -ex

run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
MODEL_NAME='llama3.1-70b'
HF_GOLDEN_MODEL='meta-llama/Llama-3.1-70B'

# To convert the multimodal model, make sure the use_multimodal is set to be true
USE_MULTIMODAL=${2:-false}

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please test whether we still need it.


# Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText

# Step 2.a: Convert to unscanned checkpoint (for inference)
python3 -m maxtext.checkpoint_conversion.to_maxtext \
model_name=${MODEL_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=false \
hardware=cpu skip_jax_distributed_system=True \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items
echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}"

# Step 2.b: Convert to scanned checkpoint (for training)
python3 -m maxtext.checkpoint_conversion.to_maxtext \
model_name=${MODEL_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=true \
hardware=cpu skip_jax_distributed_system=True \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items
echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}"

# Step 3: Test whether the forward pass logits match the original HF model
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
# ToDo: improve forward_pass_logit_checker to test multi-modal prompt
if [ "${USE_MULTIMODAL}" = "false" ]; then
python3 -m tests.utils.forward_pass_logit_checker \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
model_name=${MODEL_NAME} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=false \
--hf_model_path=${HF_GOLDEN_MODEL} \
--max_kl_div=0.03 \
--run_hf_model=true \
hardware=cpu skip_jax_distributed_system=True
fi