Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# pytype: disable=attribute-error
"""Logger that saves metrics to a local file, GCS and TensorBoard."""

import datetime
import json
import os
import sys
Expand Down Expand Up @@ -111,10 +112,14 @@ def __init__(self, config, learning_rate_schedule):
if self.config.managed_mldiagnostics:
ManagedMLDiagnostics(config) # Initialize the MLRun instance.

self.last_flush_time = datetime.datetime.now()
self.last_eval_flush_time = datetime.datetime.now()

def reset_eval_metrics(self):
  """Prepares the logger state for a new evaluation run.

  Swaps in an empty cumulative-metrics dictionary whose "scalar" entry
  defaults every missing key to 0.0, clears the count of pending eval steps,
  and restarts the eval flush timer at the current wall-clock time.
  """
  fresh_scalars = defaultdict(float)
  self.last_eval_flush_time = datetime.datetime.now()
  self._pending_eval_step_count = 0
  self.cumulative_eval_metrics = {"scalar": fresh_scalars}

def write_metrics(self, metrics, step, metric_type="train"):
"""Entry point for all metrics writing. metric_type is one of 'train', 'eval', 'running_eval'."""
Expand Down Expand Up @@ -383,24 +388,36 @@ def buffer_and_write_metrics(self, metrics, step, step_time_delta=None, is_train
if self.buffered_metrics:
self._flush_one_buffered_entry(self.buffered_metrics.pop(0))
if is_training:
self.record_train_metrics(metrics, step, step_time_delta.total_seconds())
self.buffered_metrics.append(("train", step, metrics, step_time_delta))
self.buffered_metrics.append(("train", step, metrics))
if self._pending_eval_step_count > 0:
self._finalize_eval_metrics(step)
else:
self._pending_eval_step_count += 1
self.buffered_metrics.append(("eval", step, metrics, step_time_delta))
self.buffered_metrics.append(("eval", step, metrics))

def _flush_one_buffered_entry(self, entry):
"""Dispatches a single buffered entry to the writer."""
kind = entry[0]
if kind == "train":
_, step, metrics, _ = entry
_, step, metrics = entry
# Synchronize the loss before recording step time.
_ = float(metrics["scalar"]["learning/loss"])

current_time = datetime.datetime.now()
real_step_time = (current_time - self.last_flush_time).total_seconds()
self.last_flush_time = current_time

self.record_train_metrics(metrics, step, real_step_time)
self.write_metrics(metrics, step)
elif kind == "eval":
_, eval_step, raw_metrics, step_time_delta = entry
_, eval_step, raw_metrics = entry
# _accumulate_eval_metrics calls float(), which materializes the metrics; that work is deferred to here.
self._accumulate_eval_metrics(raw_metrics)

current_time = datetime.datetime.now()
real_eval_step_time = (current_time - self.last_eval_flush_time).total_seconds()
self.last_eval_flush_time = current_time

running_count = eval_step + 1 # eval_step is 0-indexed
cumulative = self.cumulative_eval_metrics["scalar"]
running_avg_loss = cumulative["eval/total_loss"] / (cumulative["eval/total_weights"] + EPS)
Expand All @@ -411,7 +428,7 @@ def _flush_one_buffered_entry(self, entry):
"eval/total_weights": cumulative["eval/total_weights"],
"eval/avg_mtp_loss": cumulative["eval/mtp_loss"] / running_count,
"eval/avg_mtp_acceptance_rate_percent": (cumulative["eval/mtp_acceptance_rate_percent"] / running_count),
"eval/step_time_seconds": step_time_delta.total_seconds(),
"eval/step_time_seconds": real_eval_step_time,
}
}
self.write_metrics(snapshot, eval_step, metric_type="running_eval")
Expand Down
14 changes: 2 additions & 12 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# See github.com/google/maxtext/issues/20 for more

from typing import Any, Sequence
import datetime
import functools
import os

Expand Down Expand Up @@ -677,7 +676,6 @@ def train_loop(config, recorder, state=None):

_job_completed_gracefully = False
try:
last_step_completion = datetime.datetime.now()
for step in np.arange(start_step, config.steps):
prof.maybe_activate_profiler(step, state)

Expand All @@ -694,7 +692,7 @@ def train_loop(config, recorder, state=None):
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, *step_rng_args)

step_time_delta = datetime.datetime.now() - last_step_completion
metric_logger_instance.buffer_and_write_metrics(metrics, step)

checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step)

Expand All @@ -717,7 +715,6 @@ def train_loop(config, recorder, state=None):
max_logging.log(f"Starting eval after train step {step}")

eval_step_count = 0
last_eval_step_completion = datetime.datetime.now()
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
# Shard input eval data
Expand All @@ -726,21 +723,14 @@ def train_loop(config, recorder, state=None):
break
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, *step_rng_args)
eval_step_time_delta = datetime.datetime.now() - last_eval_step_completion
last_eval_step_completion = datetime.datetime.now()
metric_logger_instance.buffer_and_write_metrics(
eval_metrics, eval_step_count, step_time_delta=eval_step_time_delta, is_training=False
)
metric_logger_instance.buffer_and_write_metrics(eval_metrics, eval_step_count, is_training=False)
eval_step_count += 1

prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")

last_step_completion = datetime.datetime.now()
metric_logger_instance.buffer_and_write_metrics(metrics, step, step_time_delta)

if config.save_checkpoint_on_completion:
checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator)
if checkpoint_manager is not None:
Expand Down
Loading