Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ scan_pipeline_repeats: false
scan_layers_per_stage: false
set_remat_policy_on_pipeline_iterations: true
set_remat_policy_on_layers_per_stage: false
pipeline_save_decoder_layer_input: true # set to false to reduce pipeline tmem at cost of recomputing decoder layer inputs in backward pass


# Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
Expand Down
10 changes: 9 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,14 @@ class PipelineParallelism(BaseModel):
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")
pipeline_save_decoder_layer_input: bool = Field(
True,
description=(
"Whether to save 'decoder_layer_input' activations in the pipeline remat policy. "
"Setting to False reduces temporary memory (tmem) during pipeline execution at the cost "
"of recomputing decoder layer inputs in the backward pass."
),
)


class RematAndOffload(BaseModel):
Expand Down Expand Up @@ -2850,7 +2858,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# For AOT compilation and correctness, always prioritize the 'stage' axis for sharding when pipelining.
for rule in self.logical_axis_rules:
if rule and rule[0] == "activation_embed_and_logits_batch":
rule[1] = ["stage", "data", "fsdp", "fsdp_transpose", "expert"]
rule[1] = [ax for ax in ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if ax in self.mesh_axes]
break

if "stage" in self.mesh_axes:
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/kernels/gather_reduce_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __getitem__(self, shape):
_BF16 = VectorTypeHelper(ir.BF16Type.get)


# fmt: off
@jax.jit(
static_argnames=[
"reduce_group_size",
Expand All @@ -69,6 +70,7 @@ def __getitem__(self, shape):
"topk_wgt_zero_nan",
],
)
# fmt: on
def sc_gather_reduce(
op: jax.Array,
idx: jax.Array,
Expand Down
25 changes: 16 additions & 9 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,13 +1624,22 @@ def _sequence_descriptor(segment_ids):
dummy_attn_mask = None
mask_type = "causal"
else:
# Default case: no packing, no context parallelism
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
# Default case: no packing, no context parallelism.
# For synthetic data, segment IDs are always all-ones (one segment per sequence), so
# the segment mask is all-True and the combined mask reduces to pure causal masking.
# Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that
# XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory).
if self.config.dataset_type == "synthetic":
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
else:
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)

dpa_layer = DotProductAttention(
head_dim=head_dim,
Expand All @@ -1643,12 +1652,10 @@ def _sequence_descriptor(segment_ids):
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout,
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis=self.config.context_sharding,
context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def __init__(
mesh=mesh,
shard_mode=config.shard_mode,
debug_sharding=config.debug_sharding,
skip_trivial_specs=True,
)

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
Expand Down
17 changes: 13 additions & 4 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax
from jax import lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import NamedSharding, reshard
from maxtext.common.common_types import Array, DType, ShardMode
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned
Expand Down Expand Up @@ -78,7 +78,10 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->

if not self.with_scale:
if out_sharding is not None:
y = jax.lax.with_sharding_constraint(y, out_sharding)
if self.shard_mode == ShardMode.EXPLICIT:
y = reshard(y, out_sharding)
else:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y

scale = self.scale.get_value()
Expand All @@ -88,8 +91,14 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
scale = jax.device_put(scale, max_utils.device_space())

scale = jnp.asarray(scale, self.dtype)
effective_scale = scale + self.scale_offset
return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding)
effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale
y = y * effective_scale
if out_sharding is not None:
if self.shard_mode == ShardMode.EXPLICIT:
y = reshard(y, out_sharding)
else:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y


class GlobalRMSNorm(RMSNorm):
Expand Down
16 changes: 12 additions & 4 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
import jax.numpy as jnp
from jax.sharding import NamedSharding


import flax

try:
flax.config.update("flax_always_shard_variable", False)
except LookupError:
pass
from flax import linen as nn, nnx
from flax.linen import partitioning as nn_partitioning
from flax.nnx import variablelib
Expand Down Expand Up @@ -394,10 +401,11 @@ def diff_wrapper(curr_params, custom_params, rest, config, data):
(loss, (aux, new_rest)), (raw_grads, custom_grads) = grad_func(curr_params, custom_params, rest, config, data)
nnx.update(state.model, nnx.State.merge(custom_grads, new_rest))

raw_grads = jax.tree_util.tree_map(
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
raw_grads,
)
if config.grad_dtype != jnp.float32:
raw_grads = jax.tree_util.tree_map(
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
raw_grads,
)
if config.parameter_memory_host_offload:
raw_grads = jax.device_put(
raw_grads,
Expand Down
6 changes: 5 additions & 1 deletion src/maxtext/trainers/pre_train/train_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ def jit_and_compile(

def save_compiled(compiled, save_name):
"""Serialize and save the compiled function."""
serialized, _, _ = serialize(compiled)
result = serialize(compiled)
# jax.experimental.serialize_executable.serialize() changed its return type:
# older JAX: (bytes, in_tree, out_tree)
# newer JAX: bytes
serialized = result[0] if isinstance(result, tuple) else result
with open(save_name, "wb") as f:
f.write(serialized)

Expand Down
Loading
Loading