Resolve step timing misattribution and isolate checkpoint overhead in MaxText.#4270
Open
copybara-service[bot] wants to merge 1 commit into
Open
Resolve step timing misattribution and isolate checkpoint overhead in MaxText.#4270copybara-service[bot] wants to merge 1 commit into
copybara-service[bot] wants to merge 1 commit into
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
e4e574c to
cdc765a
Compare
… MaxText. Previously, MaxText measured training step times using host-side wall-clock intervals between loop iterations (`now() - last_step_completion`). Because JAX dispatches steps asynchronously and the metric logger flushes metrics buffered from step N-1, this logic was broken in several ways: 1. Inaccurate Step 0 Time: The reported time for step 0 was artificially low because it only measured the asynchronous dispatch time. 2. Attribution Shift: Measuring `now() - last_step_completion` effectively measured the time taken for step N-1 to complete, but attributed it to step N+1. 3. Missing Overheads: Checkpointing, HLO dumps, eval, and profiling times were not included in the step time calculations. 4. Broken Checkpoint Logs: Checkpointing acts as a synchronization barrier for step N. Because checkpointing blocked the host to perform D2H tensor transfers before the flush, it broke the logs around checkpoint boundaries, causing subsequent steps to enqueue instantly and artificially stall later. We fix this by measuring the interval between log flushes, where the loss is actually synchronized: - Eager Loss Synchronization before Step Time calculation: We force host-device synchronization on the main thread by converting the loss JAX array to a float inside `_flush_one_buffered_entry`. The step time is now calculated as the duration between consecutive eager loss synchronizations. Note that this synchronization was already occurring subsequently during final metrics serialization; moving it forward simply establishes a clean boundary to accurately measure step times. - Loop Reordering: `buffer_and_write_metrics` is moved before `maybe_save_checkpoint`. This ensures that metrics for step N-1 are flushed before step N's checkpointing sync barrier blocks the host thread. - Timer Cleanup: Inaccurate host-loop timing variables (`last_step_completion` and `step_time_delta`) are removed. This guarantees that all physical blocking copy overhead and host CPU serialization contention are accurately captured and cleanly attributed to the correct steps, eliminating the anomalies around step 0 and checkpoint boundaries. # Checklist Before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [x] I have necessary comments in my code, particularly in hard-to-understand areas. - [x] I have run end-to-end tests tests and provided workload links above if applicable. - [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation. PiperOrigin-RevId: 936672580
cdc765a to
1e01ed2
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Resolve step timing misattribution and isolate checkpoint overhead in MaxText.
Previously, MaxText measured training step times using host-side wall-clock intervals between loop iterations (
now() - last_step_completion). Because JAX dispatches steps asynchronously and the metric logger flushes metrics buffered from step N-1, this logic was broken in several ways:now() - last_step_completioneffectively measured the time taken for step N-1 to complete, but attributed it to step N+1.We fix this by measuring the interval between log flushes, where the loss is actually synchronized:
_flush_one_buffered_entry. The step time is now calculated as the duration between consecutive eager loss synchronizations. Note that this synchronization was already occurring subsequently during final metrics serialization; moving it forward simply establishes a clean boundary to accurately measure step times.buffer_and_write_metricsis moved beforemaybe_save_checkpoint. This ensures that metrics for step N-1 are flushed before step N's checkpointing sync barrier blocks the host thread.last_step_completionandstep_time_delta) are removed.This guarantees that all physical blocking copy overhead and host CPU serialization contention are accurately captured and cleanly attributed to the correct steps, eliminating the anomalies around step 0 and checkpoint boundaries.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.