From 1ee31976a984267a90af1cc482c5608d1548b376 Mon Sep 17 00:00:00 2001 From: Emma Lien Date: Fri, 26 Jun 2026 05:55:07 +0000 Subject: [PATCH 1/3] feat(testing): add end-to-end training and inference tests for gemma4-26b Introduces E2E test configurations and scripts for gemma4-26b model, covering both inference decoding and pre-training validation pipelines. --- .../end_to_end/tpu/gemma4/26b/test_gemma4.sh | 84 ++++++++++++++ .../tpu/gemma4/26b/test_gemma4_to_hf.sh | 108 +++--------------- .../tpu/gemma4/26b/test_gemma4_to_mt.sh | 74 ++++++++++++ 3 files changed, 176 insertions(+), 90 deletions(-) create mode 100644 tests/end_to_end/tpu/gemma4/26b/test_gemma4.sh create mode 100644 tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4.sh new file mode 100644 index 0000000000..262505c7d5 --- /dev/null +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Validates the Gemma4-26B pre-training pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run pre-training starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the pre-training run. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_gemma4_to_mt.sh $RUN_ID +# bash test_gemma4.sh $RUN_ID + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} +MODEL_NAME='gemma4-26b' + +# To convert the multimodal model, make sure the use_multimodal is set to be true +USE_MULTIMODAL=false + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items + +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +DATASET_PATH=gs://maxtext-dataset + +# Step 1: Run inference on the original checkpoint converted from Hugging Face +if [ ${USE_MULTIMODAL} == true ]; then + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ + scan_layers=false use_multimodal=true \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +else + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + scan_layers=false prompt='I love to' attention=\'dot_product\' +fi + +# Step 2: Run Pre-training on the converted checkpoint +# We can also run training by using the scanned converted checkpoint +# Note that scanned checkpoint helps with efficient training +python3 -m maxtext.trainers.pre_train.train \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \ + dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_target_length=8192 steps=5 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL} + +# Step 3: Run inference on the checkpoint generated from the previous run +if [ ${USE_MULTIMODAL} == true ]; then + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ + scan_layers=false use_multimodal=true \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +else + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + scan_layers=false prompt='I love to' attention=\'dot_product\' +fi + + diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh index 476aed964a..7d51483c72 100755 --- a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh @@ -1,103 +1,31 @@ #!/bin/bash -# This script is both an end-to-end test and documentation for converting a -# Gemma4-26B MaxText checkpoint to Hugging Face format. Can be run on a v5p-8. +# Converts a MaxText checkpoint to a Hugging Face model checkpoint. -# The flow of this script is as follows: -# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint. -# 2. Run a forward pass check to compare the logits and KL divergence between -# the converted checkpoint and the original HF model. - -# Pre-requisites: -# 1. Set HF_TOKEN environment variable to your Hugging Face access token. -# export HF_TOKEN= -# 2. Provide a MaxText-format Gemma4-26B checkpoint via CKPT_PATH. -# One can be produced with tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh. +# Usage: +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_gemma4_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS set -ex -idx=$(date +%Y-%m-%d-%H-%M) -MODEL_NAME='gemma4-26b' -export MODEL_VARIATION='26b-it' -# To convert the multimodal model, set USE_MULTIMODAL=true -USE_MULTIMODAL=false -# Set USE_SCAN_LAYERS=true if the checkpoint was trained with scanned layers -USE_SCAN_LAYERS=true - -# Installing torch for deps in forward_pass_logit_checker.py -python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# Non-Googlers: point MODEL_BUCKET to a GCS bucket you own. -export MODEL_BUCKET=gs://maxtext-gemma/gemma4 -# Path to a pre-existing MaxText checkpoint for gemma4-26b. Must match USE_SCAN_LAYERS. -# Run tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh to produce one. -export CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/unscanned/0/items +run_id=$1 +CKPT_PATH=$2 +USE_MULTIMODAL=${3:-false} +SCAN_LAYERS=${4:-false} -# Path to the original HF model weights for logit comparison. -export HF_MODEL=google/gemma-4-26b-a4b-it +MODEL_NAME='gemma4-26b' +BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}" -export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} +if [ "${SCAN_LAYERS,,}" = "true" ]; then + scan_status="scanned" +else + scan_status="unscanned" +fi python3 -m maxtext.checkpoint_conversion.to_huggingface \ - "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ model_name=${MODEL_NAME} \ - hf_access_token=${HF_TOKEN} \ + tokenizer_type="huggingface" \ load_parameters_path=${CKPT_PATH} \ - base_output_directory=${LOCAL_PATH} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \ use_multimodal=${USE_MULTIMODAL} \ - scan_layers=${USE_SCAN_LAYERS} - -# Run forward pass logit checker to validate the converted checkpoint. -# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips, -# not for performance purpose. -if [ "${USE_MULTIMODAL}" == true ]; then - TEST_PROMPT='Describe image <|image|>' - TEST_IMAGE='tests/assets/test_image.jpg' - export GOLDEN_LOGITS_PATH=/tmp/golden_gemma4_26b_vision.pickle - - python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ - --model-id=${HF_MODEL} \ - --output-path=${GOLDEN_LOGITS_PATH} \ - --prompts="${TEST_PROMPT}" \ - --image-paths=${TEST_IMAGE} \ - --hf-model-path=${LOCAL_PATH} \ - --output-format=pickle - - echo "=== Running MaxText Forward Pass Logit Checker ===" - python3 -m tests.utils.forward_pass_logit_checker \ - "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ - tokenizer_path=${HF_MODEL} \ - load_parameters_path=${CKPT_PATH} \ - model_name=${MODEL_NAME} \ - use_multimodal=${USE_MULTIMODAL} \ - scan_layers=${USE_SCAN_LAYERS} \ - dtype=float32 \ - wi_tile_fwd_embed_dim=512 \ - wi_tile_fwd_mlp_dim=512 \ - wo_tile_fwd_embed_dim=512 \ - wo_tile_fwd_mlp_dim=512 \ - matmul_precision=highest \ - per_device_batch_size=1 \ - attention=dot_product \ - prompt="${TEST_PROMPT}" \ - image_path=${TEST_IMAGE} \ - --max_kl_div=0.1 \ - --golden_logits_path=${GOLDEN_LOGITS_PATH} -else - echo "=== Running MaxText Forward Pass Logit Checker ===" - python3 -m tests.utils.forward_pass_logit_checker \ - "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ - tokenizer_path=${HF_MODEL} \ - load_parameters_path=${CKPT_PATH} \ - model_name=${MODEL_NAME} \ - use_multimodal=${USE_MULTIMODAL} \ - scan_layers=${USE_SCAN_LAYERS} \ - per_device_batch_size=1 \ - dtype=float32 \ - wi_tile_fwd_embed_dim=512 \ - wi_tile_fwd_mlp_dim=512 \ - wo_tile_fwd_embed_dim=512 \ - wo_tile_fwd_mlp_dim=512 \ - --max_kl_div=0.1 \ - --run_hf_model=true \ - --hf_model_path=${LOCAL_PATH} -fi + scan_layers=$SCAN_LAYERS \ No newline at end of file diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh new file mode 100644 index 0000000000..8f93dce3a9 --- /dev/null +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh @@ -0,0 +1,74 @@ +#!/bin/bash + +# Converts Gemma4-26B HuggingFace checkpoint to MaxText format and validates logit correctness. + +# The flow of this script is as follows: +# 1. Install PyTorch (CPU) required for checkpoint conversion. +# 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats. +# 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_gemma4_to_mt.sh $RUN_ID - to convert the checkpoint and run logit check for non-multimodal version +# bash test_gemma4_to_mt.sh $RUN_ID true - to convert the checkpoint and run logit check for multimodal version + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} +MODEL_NAME='gemma4-26b' +HF_GOLDEN_MODEL='google/gemma-4-26b-a4b-it' + +# To convert the multimodal model, make sure the use_multimodal is set to be true +USE_MULTIMODAL=${2:-false} + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints +BASE_OUTPUT_DIRECTORY=gs://us-central1-emmalien-test-83df3ecd-bucket/gemma4-26b-0625 + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText + +# Step 2.a: Convert to unscanned checkpoint (for inference) +python3 -m maxtext.checkpoint_conversion.to_maxtext \ + model_name=${MODEL_NAME} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=false \ + hardware=cpu skip_jax_distributed_system=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False + +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items +echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}" + +# Step 2.b: Convert to scanned checkpoint (for training) +python3 -m maxtext.checkpoint_conversion.to_maxtext \ + model_name=${MODEL_NAME} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=true \ + hardware=cpu skip_jax_distributed_system=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False + +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items +echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}" + +# Step 3: Test whether the forward pass logits match the original HF model +# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` +# ToDo: improve forward_pass_logit_checker to test multi-modal prompt +if [ "${USE_MULTIMODAL}" = "false" ]; then + python3 -m tests.utils.forward_pass_logit_checker \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + model_name=${MODEL_NAME} \ + use_multimodal=${USE_MULTIMODAL} \ + per_device_batch_size=1 \ + dtype=float32 \ + attention=dot_product \ + scan_layers=false \ + --hf_model_path=${HF_GOLDEN_MODEL} \ + --max_kl_div=0.03 \ + --run_hf_model=true \ + hardware=cpu skip_jax_distributed_system=True +fi From fd387a7ce22aa694e29d8590e45852aedc46d916 Mon Sep 17 00:00:00 2001 From: Emma Lien Date: Fri, 26 Jun 2026 07:49:05 +0000 Subject: [PATCH 2/3] fix --- tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh index 8f93dce3a9..213a30d285 100644 --- a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_mt.sh @@ -24,7 +24,7 @@ HF_GOLDEN_MODEL='google/gemma-4-26b-a4b-it' USE_MULTIMODAL=${2:-false} # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints -BASE_OUTPUT_DIRECTORY=gs://us-central1-emmalien-test-83df3ecd-bucket/gemma4-26b-0625 +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext # Step 1: Install torch python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu From 8bcfa63fc1fbd52ca2798d3d0c71e3e49bd690fa Mon Sep 17 00:00:00 2001 From: Emma Lien Date: Fri, 26 Jun 2026 08:37:33 +0000 Subject: [PATCH 3/3] add post-trainging e2e scripts --- .../tpu/gemma4/26b/test_gemma4_rl.sh | 57 +++++++++++++++++++ .../tpu/gemma4/26b/test_gemma4_sft.sh | 55 ++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 tests/end_to_end/tpu/gemma4/26b/test_gemma4_rl.sh create mode 100644 tests/end_to_end/tpu/gemma4/26b/test_gemma4_sft.sh diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_rl.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_rl.sh new file mode 100644 index 0000000000..db1f2c0b34 --- /dev/null +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_rl.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# Validates the Gemma4-26B RL pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run RL starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the RL run. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_gemma4_to_mt.sh $RUN_ID +# bash test_gemma4_rl.sh $RUN_ID + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} +use_pathways=${2:-false} +MODEL_NAME='gemma4-26b' + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items + +# Step 1: Run inference on the original checkpoint converted from Hugging Face +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt='Suggest some famous landmarks in London.' \ + use_chat_template=True scan_layers=false enable_single_controller=${use_pathways} + +# Step 2: Run RL on the converted checkpoint +python3 -m maxtext.trainers.post_train.rl.train_rl \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + run_name=${run_id} rl.loss_algo='grpo' scan_layers=true \ + num_batches=5 batch_size=1 num_test_batches=5 \ + model_name=${MODEL_NAME} enable_single_controller=${use_pathways} \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + rollout_tensor_parallelism=1 \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + vllm_additional_config='{"maxtext_config": {"model_name": "gemma4-26b", "log_config": "false"}}' + + +# Step 3: Run inference on the checkpoint generated from the previous run +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt='Suggest some famous landmarks in London.' \ + use_chat_template=True scan_layers=true enable_single_controller=${use_pathways} diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_sft.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_sft.sh new file mode 100644 index 0000000000..4453016565 --- /dev/null +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_sft.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Validates the Gemma4-26B SFT pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run SFT starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the SFT run. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) +# bash test_gemma4_to_mt.sh $RUN_ID +# bash test_gemma4_sft.sh $RUN_ID + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} +use_pathways=${2:-false} +MODEL_NAME='gemma4-26b' + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items + +# Step 1: Run inference on the original checkpoint converted from Hugging Face +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt="Suggest some famous landmarks in London." \ + use_chat_template=True scan_layers=false enable_single_controller=${use_pathways} + +# Step 2: Run SFT on the converted checkpoint +python3 -m maxtext.trainers.post_train.sft.train_sft \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/sft \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + steps=5 scan_layers=true \ + model_name=${MODEL_NAME} enable_single_controller=${use_pathways} \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False + +# Step 3: Run inference on the checkpoint generated from the previous run +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt="Suggest some famous landmarks in London." \ + use_chat_template=True scan_layers=true enable_single_controller=${use_pathways} + +