From 1e01ed2bf20c1202dc7e61292495d9cae69c56bc Mon Sep 17 00:00:00 2001 From: Aananth V Date: Tue, 23 Jun 2026 08:01:43 -0700 Subject: [PATCH] Resolve step timing misattribution and isolate checkpoint overhead in MaxText. Previously, MaxText measured training step times using host-side wall-clock intervals between loop iterations (`now() - last_step_completion`). Because JAX dispatches steps asynchronously and the metric logger flushes metrics buffered from step N-1, this logic was broken in several ways: 1. Inaccurate Step 0 Time: The reported time for step 0 was artificially low because it only measured the asynchronous dispatch time. 2. Attribution Shift: Measuring `now() - last_step_completion` effectively measured the time taken for step N-1 to complete, but attributed it to step N+1. 3. Missing Overheads: Checkpointing, HLO dumps, eval, and profiling times were not included in the step time calculations. 4. Broken Checkpoint Logs: Checkpointing acts as a synchronization barrier for step N. Because checkpointing blocked the host to perform D2H tensor transfers before the flush, it broke the logs around checkpoint boundaries, causing subsequent steps to enqueue instantly and artificially stall later. We fix this by measuring the interval between log flushes, where the loss is actually synchronized: - Eager Loss Synchronization before Step Time calculation: We force host-device synchronization on the main thread by converting the loss JAX array to a float inside `_flush_one_buffered_entry`. The step time is now calculated as the duration between consecutive eager loss synchronizations. Note that this synchronization was already occurring subsequently during final metrics serialization; moving it forward simply establishes a clean boundary to accurately measure step times. - Loop Reordering: `buffer_and_write_metrics` is moved before `maybe_save_checkpoint`. This ensures that metrics for step N-1 are flushed before step N's checkpointing sync barrier blocks the host thread. - Timer Cleanup: Inaccurate host-loop timing variables (`last_step_completion` and `step_time_delta`) are removed. This guarantees that all physical blocking copy overhead and host CPU serialization contention are accurately captured and cleanly attributed to the correct steps, eliminating the anomalies around step 0 and checkpoint boundaries. # Checklist Before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [x] I have necessary comments in my code, particularly in hard-to-understand areas. - [x] I have run end-to-end tests tests and provided workload links above if applicable. - [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation. PiperOrigin-RevId: 936672580 --- src/maxtext/common/metric_logger.py | 29 ++++++++++++++++++++----- src/maxtext/trainers/pre_train/train.py | 14 ++---------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 472ede809e..496b23c6bd 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -16,6 +16,7 @@ # pytype: disable=attribute-error """Logger that saves metrics to a local file, GCS and TensorBoard.""" +import datetime import json import os import sys @@ -111,10 +112,14 @@ def __init__(self, config, learning_rate_schedule): if self.config.managed_mldiagnostics: ManagedMLDiagnostics(config) # Initialize the MLRun instance. + self.last_flush_time = datetime.datetime.now() + self.last_eval_flush_time = datetime.datetime.now() + def reset_eval_metrics(self): """Resets the cumulative metrics dictionary for a new evaluation run.""" self.cumulative_eval_metrics = {"scalar": defaultdict(float)} self._pending_eval_step_count = 0 + self.last_eval_flush_time = datetime.datetime.now() def write_metrics(self, metrics, step, metric_type="train"): """Entry point for all metrics writing. metric_type is one of 'train', 'eval', 'running_eval'.""" @@ -383,24 +388,36 @@ def buffer_and_write_metrics(self, metrics, step, step_time_delta=None, is_train if self.buffered_metrics: self._flush_one_buffered_entry(self.buffered_metrics.pop(0)) if is_training: - self.record_train_metrics(metrics, step, step_time_delta.total_seconds()) - self.buffered_metrics.append(("train", step, metrics, step_time_delta)) + self.buffered_metrics.append(("train", step, metrics)) if self._pending_eval_step_count > 0: self._finalize_eval_metrics(step) else: self._pending_eval_step_count += 1 - self.buffered_metrics.append(("eval", step, metrics, step_time_delta)) + self.buffered_metrics.append(("eval", step, metrics)) def _flush_one_buffered_entry(self, entry): """Dispatches a single buffered entry to the writer.""" kind = entry[0] if kind == "train": - _, step, metrics, _ = entry + _, step, metrics = entry + # Synchronize the loss before recording step time. + _ = float(metrics["scalar"]["learning/loss"]) + + current_time = datetime.datetime.now() + real_step_time = (current_time - self.last_flush_time).total_seconds() + self.last_flush_time = current_time + + self.record_train_metrics(metrics, step, real_step_time) self.write_metrics(metrics, step) elif kind == "eval": - _, eval_step, raw_metrics, step_time_delta = entry + _, eval_step, raw_metrics = entry # _accumulate_eval_metrics calls float() that materialize the metrics, deferred to here self._accumulate_eval_metrics(raw_metrics) + + current_time = datetime.datetime.now() + real_eval_step_time = (current_time - self.last_eval_flush_time).total_seconds() + self.last_eval_flush_time = current_time + running_count = eval_step + 1 # eval_step is 0-indexed cumulative = self.cumulative_eval_metrics["scalar"] running_avg_loss = cumulative["eval/total_loss"] / (cumulative["eval/total_weights"] + EPS) @@ -411,7 +428,7 @@ def _flush_one_buffered_entry(self, entry): "eval/total_weights": cumulative["eval/total_weights"], "eval/avg_mtp_loss": cumulative["eval/mtp_loss"] / running_count, "eval/avg_mtp_acceptance_rate_percent": (cumulative["eval/mtp_acceptance_rate_percent"] / running_count), - "eval/step_time_seconds": step_time_delta.total_seconds(), + "eval/step_time_seconds": real_eval_step_time, } } self.write_metrics(snapshot, eval_step, metric_type="running_eval") diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 74c41e8e90..8eb00fca5e 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -19,7 +19,6 @@ # See github.com/google/maxtext/issues/20 for more from typing import Any, Sequence -import datetime import functools import os @@ -677,7 +676,6 @@ def train_loop(config, recorder, state=None): _job_completed_gracefully = False try: - last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps): prof.maybe_activate_profiler(step, state) @@ -694,7 +692,7 @@ def train_loop(config, recorder, state=None): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) state, metrics = p_train_step(state, example_batch, *step_rng_args) - step_time_delta = datetime.datetime.now() - last_step_completion + metric_logger_instance.buffer_and_write_metrics(metrics, step) checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step) @@ -717,7 +715,6 @@ def train_loop(config, recorder, state=None): max_logging.log(f"Starting eval after train step {step}") eval_step_count = 0 - last_eval_step_completion = datetime.datetime.now() # pylint: disable=not-callable for eval_batch in eval_data_iterator: # Shard input eval data @@ -726,11 +723,7 @@ def train_loop(config, recorder, state=None): break with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) - eval_step_time_delta = datetime.datetime.now() - last_eval_step_completion - last_eval_step_completion = datetime.datetime.now() - metric_logger_instance.buffer_and_write_metrics( - eval_metrics, eval_step_count, step_time_delta=eval_step_time_delta, is_training=False - ) + metric_logger_instance.buffer_and_write_metrics(eval_metrics, eval_step_count, is_training=False) eval_step_count += 1 prof.maybe_deactivate_profiler(step, state) @@ -738,9 +731,6 @@ def train_loop(config, recorder, state=None): if step == start_step: max_utils.print_mem_stats("After params initialized") - last_step_completion = datetime.datetime.now() - metric_logger_instance.buffer_and_write_metrics(metrics, step, step_time_delta) - if config.save_checkpoint_on_completion: checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator) if checkpoint_manager is not None: