-
Notifications
You must be signed in to change notification settings - Fork 542
Add llama3.1-70b E2E test scripts for pre-training #4279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
chiajunglien
wants to merge
2
commits into
AI-Hypercomputer:main
Choose a base branch
from
CIeNET-International:emma/e2e-training-llama3-70b
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+186
−0
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Validates the Gemma3-4B pre-training pipeline using a pre-converted MaxText checkpoint. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Run inference on the pre-converted checkpoint. | ||
| # 2. Run pre-training starting from the pre-converted checkpoint. | ||
| # 3. Run inference on the checkpoint produced by the pre-training run. | ||
|
|
||
| # Usage: | ||
| # export HF_TOKEN=<your Hugging Face access token> | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma3_to_mt.sh $RUN_ID | ||
| # bash test_gemma3.sh $RUN_ID | ||
|
|
||
|
|
||
| set -ex | ||
|
|
||
| run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} | ||
| MODEL_NAME='llama3.1-70b' | ||
|
|
||
| # To convert the multimodal model, make sure the use_multimodal is set to be true | ||
| USE_MULTIMODAL=false | ||
|
|
||
| # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored | ||
| BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} | ||
| UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items | ||
|
|
||
| # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data | ||
| DATASET_PATH=gs://maxtext-dataset | ||
|
|
||
| # Step 1: Run inference on the original checkpoint converted from Hugging Face | ||
| if [ ${USE_MULTIMODAL} == true ]; then | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ | ||
| scan_layers=false use_multimodal=true \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' | ||
| else | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| scan_layers=false prompt='I love to' attention=\'dot_product\' | ||
| fi | ||
|
|
||
| # Step 2: Run Pre-training on the converted checkpoint | ||
| # We can also run training by using the scanned converted checkpoint | ||
| # Note that scanned checkpoint helps with efficient training | ||
| python3 -m maxtext.trainers.pre_train.train \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \ | ||
| dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_target_length=8192 steps=5 async_checkpointing=false \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL} | ||
|
|
||
| # Step 3: Run inference on the checkpoint generated from the previous run | ||
| if [ ${USE_MULTIMODAL} == true ]; then | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ | ||
| scan_layers=false use_multimodal=true \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' | ||
| else | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| scan_layers=false prompt='I love to' attention=\'dot_product\' | ||
| fi | ||
|
|
||
|
|
31 changes: 31 additions & 0 deletions
31
tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_hf.sh
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Converts a MaxText checkpoint to a Hugging Face model checkpoint. | ||
|
|
||
| # Usage: | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma3_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS | ||
|
|
||
| set -ex | ||
|
|
||
| run_id=$1 | ||
| CKPT_PATH=$2 | ||
| USE_MULTIMODAL=${3:-false} | ||
| SCAN_LAYERS=${4:-false} | ||
|
|
||
| MODEL_NAME='llama3.1-70b' | ||
| BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}" | ||
|
|
||
| if [ "${SCAN_LAYERS,,}" = "true" ]; then | ||
| scan_status="scanned" | ||
| else | ||
| scan_status="unscanned" | ||
| fi | ||
|
|
||
| python3 -m maxtext.checkpoint_conversion.to_huggingface \ | ||
| model_name=${MODEL_NAME} \ | ||
| tokenizer_type="huggingface" \ | ||
| load_parameters_path=${CKPT_PATH} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=$SCAN_LAYERS |
71 changes: 71 additions & 0 deletions
71
tests/end_to_end/tpu/llama3.1/70b/test_llama3.1_70b_to_mt.sh
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Converts Gemma3-4B HuggingFace checkpoint to MaxText format and validates logit correctness. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Install PyTorch (CPU) required for checkpoint conversion. | ||
| # 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats. | ||
| # 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model. | ||
|
|
||
| # Usage: | ||
| # export HF_TOKEN=<your Hugging Face access token> | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma3_to_mt.sh $RUN_ID - to convert the checkpoint and run logit check for non-multimodal version | ||
| # bash test_gemma3_to_mt.sh $RUN_ID true - to convert the checkpoint and run logit check for multimodal version | ||
|
|
||
|
|
||
| set -ex | ||
|
|
||
| run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} | ||
| MODEL_NAME='llama3.1-70b' | ||
| HF_GOLDEN_MODEL='meta-llama/Llama-3.1-70B' | ||
|
|
||
| # To convert the multimodal model, make sure the use_multimodal is set to be true | ||
| USE_MULTIMODAL=${2:-false} | ||
|
|
||
| # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints | ||
| BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext | ||
|
|
||
| # Step 1: Install torch | ||
| python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
|
|
||
| # Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText | ||
|
|
||
| # Step 2.a: Convert to unscanned checkpoint (for inference) | ||
| python3 -m maxtext.checkpoint_conversion.to_maxtext \ | ||
| model_name=${MODEL_NAME} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=false \ | ||
| hardware=cpu skip_jax_distributed_system=True \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False | ||
|
|
||
| UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items | ||
| echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}" | ||
|
|
||
| # Step 2.b: Convert to scanned checkpoint (for training) | ||
| python3 -m maxtext.checkpoint_conversion.to_maxtext \ | ||
| model_name=${MODEL_NAME} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=true \ | ||
| hardware=cpu skip_jax_distributed_system=True \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False | ||
|
|
||
| SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items | ||
| echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}" | ||
|
|
||
| # Step 3: Test whether the forward pass logits match the original HF model | ||
| # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` | ||
| # ToDo: improve forward_pass_logit_checker to test multi-modal prompt | ||
| if [ "${USE_MULTIMODAL}" = "false" ]; then | ||
| python3 -m tests.utils.forward_pass_logit_checker \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| model_name=${MODEL_NAME} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=false \ | ||
| --hf_model_path=${HF_GOLDEN_MODEL} \ | ||
| --max_kl_div=0.03 \ | ||
| --run_hf_model=true \ | ||
| hardware=cpu skip_jax_distributed_system=True | ||
| fi | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please test whether we still need it.