diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 472ede809e..496b23c6bd 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -16,6 +16,7 @@ # pytype: disable=attribute-error """Logger that saves metrics to a local file, GCS and TensorBoard.""" +import datetime import json import os import sys @@ -111,10 +112,14 @@ def __init__(self, config, learning_rate_schedule): if self.config.managed_mldiagnostics: ManagedMLDiagnostics(config) # Initialize the MLRun instance. + self.last_flush_time = datetime.datetime.now() + self.last_eval_flush_time = datetime.datetime.now() + def reset_eval_metrics(self): """Resets the cumulative metrics dictionary for a new evaluation run.""" self.cumulative_eval_metrics = {"scalar": defaultdict(float)} self._pending_eval_step_count = 0 + self.last_eval_flush_time = datetime.datetime.now() def write_metrics(self, metrics, step, metric_type="train"): """Entry point for all metrics writing. metric_type is one of 'train', 'eval', 'running_eval'.""" @@ -383,24 +388,36 @@ def buffer_and_write_metrics(self, metrics, step, step_time_delta=None, is_train if self.buffered_metrics: self._flush_one_buffered_entry(self.buffered_metrics.pop(0)) if is_training: - self.record_train_metrics(metrics, step, step_time_delta.total_seconds()) - self.buffered_metrics.append(("train", step, metrics, step_time_delta)) + self.buffered_metrics.append(("train", step, metrics)) if self._pending_eval_step_count > 0: self._finalize_eval_metrics(step) else: self._pending_eval_step_count += 1 - self.buffered_metrics.append(("eval", step, metrics, step_time_delta)) + self.buffered_metrics.append(("eval", step, metrics)) def _flush_one_buffered_entry(self, entry): """Dispatches a single buffered entry to the writer.""" kind = entry[0] if kind == "train": - _, step, metrics, _ = entry + _, step, metrics = entry + # Synchronize the loss before recording step time. + _ = float(metrics["scalar"]["learning/loss"]) + + current_time = datetime.datetime.now() + real_step_time = (current_time - self.last_flush_time).total_seconds() + self.last_flush_time = current_time + + self.record_train_metrics(metrics, step, real_step_time) self.write_metrics(metrics, step) elif kind == "eval": - _, eval_step, raw_metrics, step_time_delta = entry + _, eval_step, raw_metrics = entry # _accumulate_eval_metrics calls float() that materialize the metrics, deferred to here self._accumulate_eval_metrics(raw_metrics) + + current_time = datetime.datetime.now() + real_eval_step_time = (current_time - self.last_eval_flush_time).total_seconds() + self.last_eval_flush_time = current_time + running_count = eval_step + 1 # eval_step is 0-indexed cumulative = self.cumulative_eval_metrics["scalar"] running_avg_loss = cumulative["eval/total_loss"] / (cumulative["eval/total_weights"] + EPS) @@ -411,7 +428,7 @@ def _flush_one_buffered_entry(self, entry): "eval/total_weights": cumulative["eval/total_weights"], "eval/avg_mtp_loss": cumulative["eval/mtp_loss"] / running_count, "eval/avg_mtp_acceptance_rate_percent": (cumulative["eval/mtp_acceptance_rate_percent"] / running_count), - "eval/step_time_seconds": step_time_delta.total_seconds(), + "eval/step_time_seconds": real_eval_step_time, } } self.write_metrics(snapshot, eval_step, metric_type="running_eval") diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 74c41e8e90..8eb00fca5e 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -19,7 +19,6 @@ # See github.com/google/maxtext/issues/20 for more from typing import Any, Sequence -import datetime import functools import os @@ -677,7 +676,6 @@ def train_loop(config, recorder, state=None): _job_completed_gracefully = False try: - last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps): prof.maybe_activate_profiler(step, state) @@ -694,7 +692,7 @@ def train_loop(config, recorder, state=None): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) state, metrics = p_train_step(state, example_batch, *step_rng_args) - step_time_delta = datetime.datetime.now() - last_step_completion + metric_logger_instance.buffer_and_write_metrics(metrics, step) checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step) @@ -717,7 +715,6 @@ def train_loop(config, recorder, state=None): max_logging.log(f"Starting eval after train step {step}") eval_step_count = 0 - last_eval_step_completion = datetime.datetime.now() # pylint: disable=not-callable for eval_batch in eval_data_iterator: # Shard input eval data @@ -726,11 +723,7 @@ def train_loop(config, recorder, state=None): break with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) - eval_step_time_delta = datetime.datetime.now() - last_eval_step_completion - last_eval_step_completion = datetime.datetime.now() - metric_logger_instance.buffer_and_write_metrics( - eval_metrics, eval_step_count, step_time_delta=eval_step_time_delta, is_training=False - ) + metric_logger_instance.buffer_and_write_metrics(eval_metrics, eval_step_count, is_training=False) eval_step_count += 1 prof.maybe_deactivate_profiler(step, state) @@ -738,9 +731,6 @@ def train_loop(config, recorder, state=None): if step == start_step: max_utils.print_mem_stats("After params initialized") - last_step_completion = datetime.datetime.now() - metric_logger_instance.buffer_and_write_metrics(metrics, step, step_time_delta) - if config.save_checkpoint_on_completion: checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator) if checkpoint_manager is not None: