Add gemma4-26b E2E test scripts for pre-training #4278
Draft: chiajunglien wants to merge 3 commits into AI-Hypercomputer:main from CIeNET-International:emma/e2e-training
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Validates the Gemma4-26B pre-training pipeline using a pre-converted MaxText checkpoint. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Run inference on the pre-converted checkpoint. | ||
| # 2. Run pre-training starting from the pre-converted checkpoint. | ||
| # 3. Run inference on the checkpoint produced by the pre-training run. | ||
|
|
||
| # Usage: | ||
| # export HF_TOKEN=<your Hugging Face access token> | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma4_to_mt.sh $RUN_ID | ||
| # bash test_gemma4.sh $RUN_ID | ||
|
|
||
|
|
||
| set -ex | ||
|
|
||
| run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} | ||
| MODEL_NAME='gemma4-26b' | ||
|
|
||
| # To test the multimodal model, set USE_MULTIMODAL=true | ||
| USE_MULTIMODAL=false | ||
|
|
||
| # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored | ||
| BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} | ||
| UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items | ||
|
|
||
| # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data | ||
| DATASET_PATH=gs://maxtext-dataset | ||
|
|
||
| # Step 1: Run inference on the original checkpoint converted from Hugging Face | ||
| if [ "${USE_MULTIMODAL}" = "true" ]; then | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ | ||
| scan_layers=false use_multimodal=true \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' | ||
| else | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| scan_layers=false prompt='I love to' attention=\'dot_product\' | ||
| fi | ||
|
|
||
| # Step 2: Run Pre-training on the converted checkpoint | ||
| # We can also run training by using the scanned converted checkpoint | ||
| # Note that scanned checkpoint helps with efficient training | ||
| python3 -m maxtext.trainers.pre_train.train \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \ | ||
| dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_target_length=8192 steps=5 async_checkpointing=false \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL} | ||
|
|
||
| # Step 3: Run inference on the checkpoint generated from the previous run | ||
| if [ "${USE_MULTIMODAL}" = "true" ]; then | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ | ||
| scan_layers=false use_multimodal=true \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' | ||
| else | ||
| python3 -m maxtext.inference.decode \ | ||
| model_name=${MODEL_NAME} tokenizer_type="huggingface" \ | ||
| load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| scan_layers=false prompt='I love to' attention=\'dot_product\' | ||
| fi | ||
|
|
||
|
|
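
Step 3 of the pre-training script hardcodes `checkpoints/4/items`, which only lines up with `steps=5` in Step 2. A minimal sketch of deriving that path from the step count instead, assuming the final checkpoint is written under the zero-based index of the last step (as the `4` paired with `steps=5` suggests). `final_pretrain_ckpt` is a hypothetical helper, not part of this PR, and the run ID in the call is an example value.

```shell
#!/bin/bash
# Hypothetical helper (not part of the PR): build the path of the last
# checkpoint written by the pre-training step.
# Assumption: the final checkpoint lands at zero-based step index steps-1,
# matching the script's hardcoded "checkpoints/4/items" for steps=5.
final_pretrain_ckpt() {
  local base_output_directory=$1 run_id=$2 steps=$3
  echo "${base_output_directory}/train/${run_id}/checkpoints/$((steps - 1))/items"
}

# Example call with the bucket from the script and an illustrative run ID.
final_pretrain_ckpt "gs://runner-maxtext-logs/gemma4-26b" "2025-01-01-00-00-00" 5
```

If the checkpoint period or step count changes, only the `steps` argument needs updating, rather than two separate literals.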
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Validates the Gemma4-26B RL pipeline using a pre-converted MaxText checkpoint. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Run inference on the pre-converted checkpoint. | ||
| # 2. Run RL starting from the pre-converted checkpoint. | ||
| # 3. Run inference on the checkpoint produced by the RL run. | ||
|
|
||
| # Usage: | ||
| # export HF_TOKEN=<your Hugging Face access token> | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma4_to_mt.sh $RUN_ID | ||
| # bash test_gemma4_rl.sh $RUN_ID | ||
|
|
||
|
|
||
| set -ex | ||
|
|
||
| run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} | ||
| use_pathways=${2:-false} | ||
| MODEL_NAME='gemma4-26b' | ||
|
|
||
| # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored | ||
| BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} | ||
| UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items | ||
| SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items | ||
|
|
||
| # Step 1: Run inference on the original checkpoint converted from Hugging Face | ||
| python3 -m maxtext.inference.vllm_decode \ | ||
| model_name=${MODEL_NAME} \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ | ||
| hbm_utilization_vllm=0.5 \ | ||
| prompt='Suggest some famous landmarks in London.' \ | ||
| use_chat_template=True scan_layers=false enable_single_controller=${use_pathways} | ||
|
|
||
| # Step 2: Run RL on the converted checkpoint | ||
| python3 -m maxtext.trainers.post_train.rl.train_rl \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \ | ||
| load_parameters_path=${SCANNED_CKPT_PATH} \ | ||
| run_name=${run_id} rl.loss_algo='grpo' scan_layers=true \ | ||
| num_batches=5 batch_size=1 num_test_batches=5 \ | ||
| model_name=${MODEL_NAME} enable_single_controller=${use_pathways} \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ | ||
| rollout_tensor_parallelism=1 \ | ||
| vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ | ||
| vllm_additional_config='{"maxtext_config": {"model_name": "gemma4-26b", "log_config": "false"}}' | ||
|
|
||
|
|
||
| # Step 3: Run inference on the checkpoint generated from the previous run | ||
| python3 -m maxtext.inference.vllm_decode \ | ||
| model_name=${MODEL_NAME} \ | ||
| load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \ | ||
| vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ | ||
| hbm_utilization_vllm=0.5 \ | ||
| prompt='Suggest some famous landmarks in London.' \ | ||
| use_chat_template=True scan_layers=true enable_single_controller=${use_pathways} |
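
The RL script forwards its optional second positional argument straight into `enable_single_controller`. A small sketch of normalizing that argument first, so that a typo fails fast instead of reaching the trainer. `parse_bool` is a hypothetical helper, not part of this PR, and the `${1,,}` lowercase expansion assumes bash 4 or newer.

```shell
#!/bin/bash
# Hypothetical helper (not part of the PR): normalize a true/false argument.
# Requires bash 4+ for the ${var,,} lowercase expansion.
parse_bool() {
  case "${1,,}" in
    true) echo true ;;
    false) echo false ;;
    *) echo "expected true or false, got '$1'" >&2; return 1 ;;
  esac
}

# Mirrors `use_pathways=${2:-false}` from the script, with validation added.
use_pathways=$(parse_bool "${2:-false}")
echo "enable_single_controller=${use_pathways}"
```

Because the script runs under `set -e`, a failed `parse_bool` in the assignment would stop the run before any decode or training job is launched.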
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Validates the Gemma4-26B SFT pipeline using a pre-converted MaxText checkpoint. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Run inference on the pre-converted checkpoint. | ||
| # 2. Run SFT starting from the pre-converted checkpoint. | ||
| # 3. Run inference on the checkpoint produced by the SFT run. | ||
|
|
||
| # Usage: | ||
| # export HF_TOKEN=<your Hugging Face access token> | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma4_to_mt.sh $RUN_ID | ||
| # bash test_gemma4_sft.sh $RUN_ID | ||
|
|
||
|
|
||
| set -ex | ||
|
|
||
| run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} | ||
| use_pathways=${2:-false} | ||
| MODEL_NAME='gemma4-26b' | ||
|
|
||
| # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored | ||
| BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} | ||
| UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items | ||
| SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items | ||
|
|
||
| # Step 1: Run inference on the original checkpoint converted from Hugging Face | ||
| python3 -m maxtext.inference.vllm_decode \ | ||
| model_name=${MODEL_NAME} \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ | ||
| hbm_utilization_vllm=0.5 \ | ||
| prompt="Suggest some famous landmarks in London." \ | ||
| use_chat_template=True scan_layers=false enable_single_controller=${use_pathways} | ||
|
|
||
| # Step 2: Run SFT on the converted checkpoint | ||
| python3 -m maxtext.trainers.post_train.sft.train_sft \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/sft \ | ||
| load_parameters_path=${SCANNED_CKPT_PATH} \ | ||
| per_device_batch_size=1 run_name=${run_id} \ | ||
| steps=5 scan_layers=true \ | ||
| model_name=${MODEL_NAME} enable_single_controller=${use_pathways} \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False | ||
|
|
||
| # Step 3: Run inference on the checkpoint generated from the previous run | ||
| python3 -m maxtext.inference.vllm_decode \ | ||
| model_name=${MODEL_NAME} \ | ||
| load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \ | ||
| vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ | ||
| hbm_utilization_vllm=0.5 \ | ||
| prompt="Suggest some famous landmarks in London." \ | ||
| use_chat_template=True scan_layers=true enable_single_controller=${use_pathways} | ||
|
|
||
|
|
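
The pre-training, RL, and SFT scripts each rebuild the converted-checkpoint paths by hand, and those paths only resolve if `test_gemma4_to_mt.sh` was run first with the same run ID. A sketch of the shared layout as one function. `converted_ckpt_path` is a hypothetical name, not part of this PR, and the bucket is the default shown in the scripts.

```shell
#!/bin/bash
# Hypothetical helper (not part of the PR): path of a checkpoint written by
# test_gemma4_to_mt.sh, following the layout used across the scripts:
#   <base>/to_maxtext/<scanned|unscanned>/<run_id>/0/items
converted_ckpt_path() {
  local base_output_directory=$1 scan_status=$2 run_id=$3
  case "${scan_status}" in
    scanned|unscanned) ;;
    *) echo "scan_status must be scanned or unscanned" >&2; return 1 ;;
  esac
  echo "${base_output_directory}/to_maxtext/${scan_status}/${run_id}/0/items"
}

BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma4-26b
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
UNSCANNED_CKPT_PATH=$(converted_ckpt_path "${BASE_OUTPUT_DIRECTORY}" unscanned "${run_id}")
SCANNED_CKPT_PATH=$(converted_ckpt_path "${BASE_OUTPUT_DIRECTORY}" scanned "${run_id}")
echo "${UNSCANNED_CKPT_PATH}"
echo "${SCANNED_CKPT_PATH}"
```

Keeping the layout in one place would make it harder for a consumer script to drift from the directory structure the conversion step writes.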
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,103 +1,31 @@ | ||
| #!/bin/bash | ||
|
|
||
| # This script is both an end-to-end test and documentation for converting a | ||
| # Gemma4-26B MaxText checkpoint to Hugging Face format. Can be run on a v5p-8. | ||
| # Converts a MaxText checkpoint to a Hugging Face model checkpoint. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint. | ||
| # 2. Run a forward pass check to compare the logits and KL divergence between | ||
| # the converted checkpoint and the original HF model. | ||
|
|
||
| # Pre-requisites: | ||
| # 1. Set HF_TOKEN environment variable to your Hugging Face access token. | ||
| # export HF_TOKEN=<Hugging Face access token> | ||
| # 2. Provide a MaxText-format Gemma4-26B checkpoint via CKPT_PATH. | ||
| # One can be produced with tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh. | ||
| # Usage: | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma4_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS | ||
|
|
||
| set -ex | ||
| idx=$(date +%Y-%m-%d-%H-%M) | ||
| MODEL_NAME='gemma4-26b' | ||
| export MODEL_VARIATION='26b-it' | ||
| # To convert the multimodal model, set USE_MULTIMODAL=true | ||
| USE_MULTIMODAL=false | ||
| # Set USE_SCAN_LAYERS=true if the checkpoint was trained with scanned layers | ||
| USE_SCAN_LAYERS=true | ||
|
|
||
| # Installing torch for deps in forward_pass_logit_checker.py | ||
| python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
|
|
||
| # Non-Googlers: point MODEL_BUCKET to a GCS bucket you own. | ||
| export MODEL_BUCKET=gs://maxtext-gemma/gemma4 | ||
| # Path to a pre-existing MaxText checkpoint for gemma4-26b. Must match USE_SCAN_LAYERS. | ||
| # Run tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh to produce one. | ||
| export CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/unscanned/0/items | ||
| run_id=$1 | ||
| CKPT_PATH=$2 | ||
| USE_MULTIMODAL=${3:-false} | ||
| SCAN_LAYERS=${4:-false} | ||
|
|
||
| # Path to the original HF model weights for logit comparison. | ||
| export HF_MODEL=google/gemma-4-26b-a4b-it | ||
| MODEL_NAME='gemma4-26b' | ||
| BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}" | ||
|
|
||
| export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} | ||
| if [ "${SCAN_LAYERS,,}" = "true" ]; then | ||
| scan_status="scanned" | ||
| else | ||
| scan_status="unscanned" | ||
| fi | ||
|
|
||
| python3 -m maxtext.checkpoint_conversion.to_huggingface \ | ||
| "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ | ||
| model_name=${MODEL_NAME} \ | ||
| hf_access_token=${HF_TOKEN} \ | ||
| tokenizer_type="huggingface" \ | ||
| load_parameters_path=${CKPT_PATH} \ | ||
| base_output_directory=${LOCAL_PATH} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=${USE_SCAN_LAYERS} | ||
|
|
||
| # Run forward pass logit checker to validate the converted checkpoint. | ||
| # The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips, | ||
| # not for performance purpose. | ||
| if [ "${USE_MULTIMODAL}" == true ]; then | ||
| TEST_PROMPT='Describe image <|image|>' | ||
| TEST_IMAGE='tests/assets/test_image.jpg' | ||
| export GOLDEN_LOGITS_PATH=/tmp/golden_gemma4_26b_vision.pickle | ||
|
|
||
| python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ | ||
| --model-id=${HF_MODEL} \ | ||
| --output-path=${GOLDEN_LOGITS_PATH} \ | ||
| --prompts="${TEST_PROMPT}" \ | ||
| --image-paths=${TEST_IMAGE} \ | ||
| --hf-model-path=${LOCAL_PATH} \ | ||
| --output-format=pickle | ||
|
|
||
| echo "=== Running MaxText Forward Pass Logit Checker ===" | ||
| python3 -m tests.utils.forward_pass_logit_checker \ | ||
| "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ | ||
| tokenizer_path=${HF_MODEL} \ | ||
| load_parameters_path=${CKPT_PATH} \ | ||
| model_name=${MODEL_NAME} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=${USE_SCAN_LAYERS} \ | ||
| dtype=float32 \ | ||
| wi_tile_fwd_embed_dim=512 \ | ||
| wi_tile_fwd_mlp_dim=512 \ | ||
| wo_tile_fwd_embed_dim=512 \ | ||
| wo_tile_fwd_mlp_dim=512 \ | ||
| matmul_precision=highest \ | ||
| per_device_batch_size=1 \ | ||
| attention=dot_product \ | ||
| prompt="${TEST_PROMPT}" \ | ||
| image_path=${TEST_IMAGE} \ | ||
| --max_kl_div=0.1 \ | ||
| --golden_logits_path=${GOLDEN_LOGITS_PATH} | ||
| else | ||
| echo "=== Running MaxText Forward Pass Logit Checker ===" | ||
| python3 -m tests.utils.forward_pass_logit_checker \ | ||
| "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ | ||
| tokenizer_path=${HF_MODEL} \ | ||
| load_parameters_path=${CKPT_PATH} \ | ||
| model_name=${MODEL_NAME} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=${USE_SCAN_LAYERS} \ | ||
| per_device_batch_size=1 \ | ||
| dtype=float32 \ | ||
| wi_tile_fwd_embed_dim=512 \ | ||
| wi_tile_fwd_mlp_dim=512 \ | ||
| wo_tile_fwd_embed_dim=512 \ | ||
| wo_tile_fwd_mlp_dim=512 \ | ||
| --max_kl_div=0.1 \ | ||
| --run_hf_model=true \ | ||
| --hf_model_path=${LOCAL_PATH} | ||
| fi | ||
| scan_layers=$SCAN_LAYERS |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Converts Gemma4-26B HuggingFace checkpoint to MaxText format and validates logit correctness. | ||
|
|
||
| # The flow of this script is as follows: | ||
| # 1. Install PyTorch (CPU) required for checkpoint conversion. | ||
| # 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats. | ||
| # 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model. | ||
|
|
||
| # Usage: | ||
| # export HF_TOKEN=<your Hugging Face access token> | ||
| # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S) | ||
| # bash test_gemma4_to_mt.sh $RUN_ID - convert the text-only checkpoint and run the logit check | ||
| # bash test_gemma4_to_mt.sh $RUN_ID true - convert the multimodal checkpoint (the logit check is skipped) | ||
|
|
||
|
|
||
| set -ex | ||
|
|
||
| run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)} | ||
| MODEL_NAME='gemma4-26b' | ||
| HF_GOLDEN_MODEL='google/gemma-4-26b-a4b-it' | ||
|
|
||
| # To convert the multimodal model, pass true as the second argument | ||
| USE_MULTIMODAL=${2:-false} | ||
|
|
||
| # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints | ||
| BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext | ||
|
|
||
| # Step 1: Install torch | ||
| python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
|
|
||
| # Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText | ||
|
|
||
| # Step 2.a: Convert to unscanned checkpoint (for inference) | ||
| python3 -m maxtext.checkpoint_conversion.to_maxtext \ | ||
| model_name=${MODEL_NAME} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=false \ | ||
| hardware=cpu skip_jax_distributed_system=True \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False | ||
|
|
||
| UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items | ||
| echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}" | ||
|
|
||
| # Step 2.b: Convert to scanned checkpoint (for training) | ||
| python3 -m maxtext.checkpoint_conversion.to_maxtext \ | ||
| model_name=${MODEL_NAME} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| scan_layers=true \ | ||
| hardware=cpu skip_jax_distributed_system=True \ | ||
| checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False | ||
|
|
||
| SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items | ||
| echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}" | ||
|
|
||
| # Step 3: Test whether the forward pass logits match the original HF model. | ||
| # To get higher precision (e.g. float32), run on CPU with `JAX_PLATFORMS=cpu`. | ||
| # TODO: extend forward_pass_logit_checker to test multimodal prompts. | ||
| if [ "${USE_MULTIMODAL}" = "false" ]; then | ||
| python3 -m tests.utils.forward_pass_logit_checker \ | ||
| load_parameters_path=${UNSCANNED_CKPT_PATH} \ | ||
| model_name=${MODEL_NAME} \ | ||
| use_multimodal=${USE_MULTIMODAL} \ | ||
| per_device_batch_size=1 \ | ||
| dtype=float32 \ | ||
| attention=dot_product \ | ||
| scan_layers=false \ | ||
| --hf_model_path=${HF_GOLDEN_MODEL} \ | ||
| --max_kl_div=0.03 \ | ||
| --run_hf_model=true \ | ||
| hardware=cpu skip_jax_distributed_system=True | ||
| fi | ||
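
Steps 2.a and 2.b of the conversion script differ only in the output subdirectory and the `scan_layers` value. A sketch of that pairing as a loop, written as a dry run that prints the per-layout overrides rather than invoking the converter. `scan_layers_for` is a hypothetical helper, not part of this PR.

```shell
#!/bin/bash
# Hypothetical helper (not part of the PR): map a checkpoint layout name to
# the scan_layers value the conversion step passes for it.
scan_layers_for() {
  case "$1" in
    scanned) echo true ;;
    unscanned) echo false ;;
    *) echo "unknown layout: $1" >&2; return 1 ;;
  esac
}

BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma4-26b/to_maxtext
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}

# Dry run: print the per-layout overrides instead of invoking the converter.
for layout in unscanned scanned; do
  echo "base_output_directory=${BASE_OUTPUT_DIRECTORY}/${layout}/${run_id}" \
       "scan_layers=$(scan_layers_for "${layout}")"
done
```

The same mapping runs in the opposite direction in `test_gemma4_to_hf.sh`, which turns `SCAN_LAYERS` into a `scanned` or `unscanned` path segment.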
Please test whether we still need it.
Using the command below to test:
and got the error:
It seems torch still needs to be installed.
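
Since the scripts still depend on torch, one option is to make the install step conditional so it is skipped wherever torch is already importable. A minimal sketch under that assumption. `module_missing` and the `PYTHON` override variable are hypothetical and not part of this PR. The install command in the usage comment is the one the scripts already use.

```shell
#!/bin/bash
# Hypothetical helper (not part of the PR): succeed when a Python module
# cannot be imported. PYTHON is an assumed override for the interpreter.
module_missing() {
  ! "${PYTHON:-python3}" -c "import $1" >/dev/null 2>&1
}

# Usage sketch (commented out so this block has no side effects):
# if module_missing torch; then
#   python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# fi
```

This only avoids a redundant reinstall. It does not change the fact that the conversion and logit-check steps need torch.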