From 631941c0c4d70c98d23d8d59cb58dc87e7ec12d6 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Wed, 17 Jun 2026 18:10:16 +0000 Subject: [PATCH 1/2] feat(models): Integrate DeepSeek V4 architecture and routing --- src/maxtext/common/common_types.py | 1 + src/maxtext/configs/models/deepseek4-284b.yml | 64 ++++ src/maxtext/configs/types.py | 6 +- src/maxtext/layers/attention_compressed.py | 37 +-- src/maxtext/layers/attentions.py | 1 + src/maxtext/layers/decoders.py | 110 ++++++- src/maxtext/layers/embeddings.py | 17 +- src/maxtext/layers/moe.py | 19 +- src/maxtext/models/deepseek.py | 68 ++--- src/maxtext/models/deepseek4.py | 274 ++++++++++++++++++ src/maxtext/utils/globals.py | 1 + tests/unit/deepseek_v4_vs_reference_test.py | 39 +-- tests/unit/train_compile_test.py | 20 ++ 13 files changed, 578 insertions(+), 79 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-284b.yml create mode 100644 src/maxtext/models/deepseek4.py diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..71dbc105d4 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum): SIMPLE_MLP = "simple_mlp" LLAMA4 = "llama4" OLMO3 = "olmo3" + DEEPSEEK4 = "deepseek4" class AttentionType(enum.Enum): diff --git a/src/maxtext/configs/models/deepseek4-284b.yml b/src/maxtext/configs/models/deepseek4-284b.yml new file mode 100644 index 0000000000..5ba2dd062f --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -0,0 +1,64 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for DeepSeek-V4-Flash 284B (https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +vocab_size: 129280 +head_dim: 512 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 128 +indexer_n_heads: 64 +indexer_topk: 512 + +# Note: Layers (0, 1, 2) are prefix layers. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 3 prefix [0,0,4] + 40 scanned. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 256 +num_experts_per_tok: 6 +mlp_activations_limit: 10 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" + +# --- Attention configuration --- +attention_type: 'compressed' +q_lora_rank: 1024 +o_groups: 8 +o_lora_rank: 1024 +sliding_window_size: 128 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 1048576 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e43f34f247..d1f293aae8 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -227,7 +227,7 @@ class ProfilerType(str, Enum): "deepseek3-test", "deepseek3-tiny", "deepseek3.2-671b", - "deepseek4", + "deepseek4-284b", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -553,7 +553,7 @@ class Attention(BaseModel): "autoselected", description="The attention algorithm to use (dot_product, flash, etc).", ) - attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field( + attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field( "global", description="The variant of attention to use." ) share_kv_projections: bool = Field( @@ -2925,6 +2925,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.") if self.moba and self.attention not in ("dot_product"): raise ValueError("MoBA is only supported with dot_product attention.") + if self.decoder_block == DecoderBlockType.DEEPSEEK4 and self.attention != "dot_product": + raise ValueError("DeepSeek4 decoder block currently only supports dot_product attention.") if self.use_indexer: if self.q_lora_rank == 0: raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.") diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index e9a25f46b5..391ec6cedd 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -680,24 +680,23 @@ def __init__( rngs: Optional[nnx.Rngs] = None, **kwargs, ): - """Initializes the CompressedAttention layer. + """Inherits all standard Attention hyperparameters and selectively instantiates + an underlying HCA or CSA compressor based on the provided `compress_ratio`. - Inherits all standard Attention hyperparameters and selectively instantiates - an underlying HCA or CSA compressor based on the provided `layer_type`. + Highlights of DeepSeek-V4 attention integration: + - Shared-KV: The layer supports decoupling Q and KV heads for heavy compression. + - MQA: Multi-Query Attention used alongside heavy KV compression. + - 3 Different Attention Modes: Sliding Window (prefix), HCA (128x), and CSA (4x). + - Dual RoPE Theta: Uses 10000 for standard uncompressed tokens and 160000 for compressed. Args: (See maxtext.layers.attentions.Attention for standard attention arguments) q_lora_rank: The rank for the LoRA projection in the compressed query. - compress_ratio: The compression ratio for the compressor. + compress_ratio: The compression ratio (0, 4, or 128) for the compressor. """ - """Initializes the Compressed Attention module.""" self.q_lora_rank = q_lora_rank self.compress_ratio = compress_ratio - # Determine the correct underlying attention type based on the compress_ratio - if self.compress_ratio == 0: - attention_type = AttentionType.LOCAL_SLIDING - super().__init__( config=config, num_query_heads=num_query_heads, @@ -809,20 +808,22 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) - # DeepSeek-V4 uses a separate RoPE theta (160000) for compressed tokens. - # We must instantiate a dedicated rotary embedding for the compressors - self.compress_rotary_embedding = DeepSeekV4RotaryEmbedding( + # Override the base rotary embedding with the correct theta for this layer. + # CSA / HCA layers use compressed_rope_max_timescale (160000). + # Sliding window prefix layers use rope_max_timescale (10000). + rope_theta = self.config.compressed_rope_max_timescale if self.compress_ratio > 0 else self.config.rope_max_timescale + self.rotary_embedding = DeepSeekV4RotaryEmbedding( head_dim=self.config.head_dim, - partial_rotary_factor=1.0, - rope_theta=self.config.compressed_rope_max_timescale, - dtype=self.dtype, + partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim, + rope_theta=rope_theta, + fprop_dtype=self.dtype, ) if self.compress_ratio > 4: self.hca_compressor = DeepseekV4HCACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -832,7 +833,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No self.csa_compressor = DeepseekV4CSACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -1047,7 +1048,7 @@ def __call__( # -> [batch, q_length, emb_dim] final_out = self.o_b_proj(grouped_flat) - return final_out + return final_out, None def compressed_attention( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 679c891360..ab7673d1d4 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -850,6 +850,7 @@ def init_rotary_embedding(self): shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif self.is_qwen3_hybrid: rotary_embedding = PartialRotaryEmbedding( min_timescale=self.config.rope_min_timescale, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index b28b6dcb7a..0150c7b401 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -41,6 +41,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.models import ( deepseek, + deepseek4, deepseek_batchsplit, deepseek_batchsplit_fp8, gemma, @@ -467,6 +468,10 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK4: + return ( + [deepseek4.DeepSeek4ScannableBlockToLinen] if self.config.scan_layers else [deepseek4.DeepSeek4LayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -632,6 +637,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.MISTRAL, DecoderBlockType.MIXTRAL, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK4, DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, @@ -1061,6 +1067,17 @@ def __call__( previous_chunk, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK4: + y = self._apply_deepseek4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1195,7 +1212,7 @@ def __call__( "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), } - if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5, DecoderBlockType.DEEPSEEK4): layer_kwargs = {"layer_idx": lyr} kv_cache = None if kv_caches is not None: @@ -1423,6 +1440,97 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ): + """Applies DeepSeek V4 scanned decoder blocks. + + DeepSeek V4 has some number of prefix layers (defined by `first_num_hash_layers`) + that use static Hash Routing. The remaining layers alternate `compress_ratio=128` (HCA) + and `compress_ratio=4` (CSA) and are evaluated in a single `nn.scan` block. + + For DeepSeek4-Flash (43 hidden layers total): + - 3 Prefix layers (Indices 0, 1, 2) + - 40 Scanned layers: 20 perfectly repeating chunks of [128, 4] + """ + + cfg = self.config + mesh = self.mesh + + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot, + previous_chunk, + ) + + layer_call_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "decoder_input_tokens": decoder_input_tokens, + } + + # 1. Prefix Unrolling + # These layers use Hash Routing. + num_hash_layers = cfg.first_num_hash_layers + for layer_idx in range(num_hash_layers): + prefix_layer = deepseek4.DeepSeek4LayerToLinen( + config=cfg, + mesh=mesh, + name=f"layers_{layer_idx}", + quant=self.quant, + model_mode=self.model_mode, + layer_idx=layer_idx, + ) + y, _ = prefix_layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + **layer_call_kwargs, + ) + + # 2. Chunked Scanning + # The remaining layers perfectly alternate HCA (128) and CSA (4). + num_remaining_layers = cfg.num_decoder_layers - num_hash_layers + num_full_blocks = num_remaining_layers // 2 + + if num_full_blocks > 0: + ScannableBlockToLinen = deepseek4.DeepSeek4ScannableBlockToLinen + policy = self.get_remat_policy() + RemattedDeepSeek4Block = self.set_remat_policy([ScannableBlockToLinen], policy)[0] + + y, _ = nn.scan( + RemattedDeepSeek4Block, + variable_axes={ + "params": cfg.param_scan_axis, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True, "dropout": cfg.enable_dropout}, + in_axes=(nn.broadcast,) * len(broadcast_args), + length=num_full_blocks, + metadata_params={ + nn.PARTITION_NAME: "layers", + "abstract_init": False, + }, + )(config=cfg, mesh=mesh, quant=self.quant, model_mode=model_mode, name="scanned_blocks",)(y, *broadcast_args) + + return y + def _apply_gemma4_small_layers( self, y, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 86b6723bd5..ad6b171f2f 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -1803,7 +1803,7 @@ def qwen3_omni_mrope_embedding_as_linen( ) -class DeepSeekV4RotaryEmbedding(nnx.Module): +class DeepSeekV4RotaryEmbedding(RotaryEmbedding): """DeepSeek-V4 partial rotary embedding with interleaved frequencies. DeepSeek-V4 uses an interleaved positional encoding where consecutive channels @@ -1822,12 +1822,23 @@ def __init__( head_dim: int, partial_rotary_factor: float = 64.0 / 512.0, rope_theta: float = 10000.0, - dtype: Any = jnp.float32, + fprop_dtype: Any = jnp.float32, + min_timescale: int = 10000, + max_timescale: int = 10000, + mesh: Any = None, + **kwargs, ): + super().__init__( + min_timescale=min_timescale, + max_timescale=max_timescale, + mesh=mesh, + fprop_dtype=fprop_dtype, + **kwargs, + ) self.head_dim = head_dim self.partial_rotary_factor = partial_rotary_factor self.rope_theta = rope_theta - self.dtype = dtype + self.fprop_dtype = fprop_dtype # Compute the partial rotary dimension (rope_head_dim) self.dim = int(head_dim * partial_rotary_factor) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..4bb7cc7c08 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -208,6 +208,10 @@ def calculate_load_balance_updates(top_k_indices, num_experts, rate): return output +class Tid2EidVar(nnx.Variable): + """Custom variable to hold tid2eid without trainable param overhead.""" + + class GateLogit(nnx.Module): """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing.""" @@ -399,8 +403,11 @@ def __init__( # DeepSeek V4 Hash Routing if self.is_hash_routing: # Token-ID to Expert-ID lookup table for static routing - self.tid2eid = nnx.Variable( - jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.int32), + # Must be stored as float32 because MaxText passes the entire variable tree + # through jax.value_and_grad, which strictly requires all leaves to be inexact types + # (even if they receive no gradients). We cast to int32 dynamically during routing. + self.tid2eid = Tid2EidVar( + jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.float32), out_sharding=None, # Replicated across shards for local lookup ) else: @@ -665,7 +672,13 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None): return top_k_weights, top_k_indices if self.is_hash_routing: - top_k_indices = self.tid2eid[input_ids] + if input_ids is None: + raise ValueError("input_ids cannot be None when is_hash_routing is True") + # Access the static routing table + tid2eid_int = self.tid2eid.value + # Cast the float32 array to int32 (JAX automatically assigns 0.0 gradients to integer casts) + tid2eid_int = tid2eid_int.astype(jnp.int32) + top_k_indices = tid2eid_int[input_ids] top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1) # NOTE: deepseek2 has a different pattern elif self.config.model_name.startswith(("deepseek3", "deepseek4")): diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..d3a72b31bf 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL +from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL, DecoderBlockType from maxtext.layers import attention_mla from maxtext.layers import initializers from maxtext.layers import linears @@ -138,37 +138,39 @@ def __init__( self.engram_layer_norm = None self.engram = None - self.self_attention = attention_mla.MLA( - config=self.config, - num_query_heads=self.config.num_query_heads, - num_kv_heads=self.config.num_kv_heads, - head_dim=self.config.head_dim, - max_target_length=self.config.max_target_length, - max_prefill_predict_length=self.config.max_prefill_predict_length, - attention_kernel=self.config.attention, - attention_type=self.config.attention_type, - inputs_q_shape=self.dummy_inputs_shape, - inputs_kv_shape=self.dummy_inputs_shape, - mesh=mesh, - dtype=self.config.dtype, - weight_dtype=self.config.weight_dtype, - dropout_rate=self.config.dropout_rate, - name="self_attention", - quant=quant, - kv_quant=quantizations.configure_kv_quant(config), - q_lora_rank=self.config.q_lora_rank, - kv_lora_rank=self.config.kv_lora_rank, - qk_nope_head_dim=self.config.qk_nope_head_dim, - qk_rope_head_dim=self.config.qk_rope_head_dim, - v_head_dim=self.config.v_head_dim, - max_position_embeddings=self.config.max_position_embeddings, - original_max_position_embeddings=self.config.original_max_position_embeddings, - mscale=self.config.mscale, - rope_factor=self.config.rope_factor, - model_mode=model_mode, - rngs=rngs, - attn_logits_soft_cap=self.config.attn_logits_soft_cap, - ) + # DeepSeek V4 natively overrides this block with CompressedAttention. + if self.config.decoder_block != DecoderBlockType.DEEPSEEK4: + self.self_attention = attention_mla.MLA( + config=self.config, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.config.attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(self.config), + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + original_max_position_embeddings=self.config.original_max_position_embeddings, + mscale=self.config.mscale, + rope_factor=self.config.rope_factor, + model_mode=model_mode, + rngs=rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) if self.is_mhc_enabled: @@ -333,7 +335,7 @@ def __init__( rngs=self.rngs, ) - def mlp_op(self, x, deterministic): + def mlp_op(self, x, deterministic, *args, **kwargs): mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) return self.with_logical_constraint(mlp) diff --git a/src/maxtext/models/deepseek4.py b/src/maxtext/models/deepseek4.py new file mode 100644 index 0000000000..12b0b83823 --- /dev/null +++ b/src/maxtext/models/deepseek4.py @@ -0,0 +1,274 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepSeek-V4 model definition.""" + +from typing import Optional + +from flax import nnx +import flax.linen as nn +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, AttentionType +from maxtext.common.common_types import HyperConnectionType +from maxtext.layers import attention_compressed +from maxtext.layers import initializers +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.models import deepseek +from jax.ad_checkpoint import checkpoint_name + + +class DeepSeek4DecoderLayer(deepseek.DeepSeekGenericLayer): + """DeepSeek-V4 specific decoder layer. + + Note: V4 does not utilize purely dense layers in the initial transformer blocks. + Every layer is a Sparse MoE layer (which internally contains shared dense experts). + + Args: + config: Configuration for the model. + model_mode: The mode of the model (e.g. 'train', 'inference'). + mesh: JAX sharding mesh. + rngs: NNX Rngs. + quant: Optional AQT quantization config. + layer_idx: The index of the layer. + compress_ratio: DeepSeek V4 specific parameter defining the KV cache compression + ratio. Expected values are 0 (no compression, sliding window), 4 (CSA), or 128 (HCA). + is_hash_routing: DeepSeek V4 specific parameter defining if this layer uses + static deterministic hash routing (used in prefix layers). + """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + compress_ratio: Optional[int] = None, + is_hash_routing: Optional[bool] = None, + ) -> None: + super().__init__( + config=config, + model_mode=model_mode, + mesh=mesh, + rngs=rngs, + quant=quant, + layer_idx=layer_idx, + ) + + # DeepSeek V4 applies Hash Routing to the first `config.first_num_hash_layers` layers. + # For the unscannable prefix layers, we can safely determine this using `layer_idx`. + # However, for layers inside `nn.scan` blocks, `layer_idx` is a dynamic JAX tracer + # and cannot be evaluated as a boolean condition. Since all scannable layers occur + # after the hash-routed prefix, the scannable block explicitly passes + # `is_hash_routing=False` to safely bypass this check. + if is_hash_routing is None: + is_hash_routing = layer_idx < config.first_num_hash_layers + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + is_hash_routing=is_hash_routing, + rngs=rngs, + ) + + if compress_ratio is None: + compress_ratio = config.compress_ratios[layer_idx] + + # Route to LOCAL_SLIDING if compression is disabled for this layer, + # otherwise default to the globally configured attention type (e.g., COMPRESSED). + layer_attention_type = ( + AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType(self.config.attention_type) + ) + + self.self_attention = attention_compressed.CompressedAttention( + config=self.config, + compress_ratio=compress_ratio, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=layer_attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=self.mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + sliding_window_size=self.config.sliding_window_size, + q_lora_rank=self.config.q_lora_rank, + name=f"compressed_attention_layer_{layer_idx}", + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, + rngs=rngs, + ) + + # pylint: disable=arguments-differ + def mlp_op(self, inputs, deterministic, *args, **kwargs): + input_ids = kwargs.get("input_ids") + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp( + inputs=inputs, + input_ids=input_ids, + ) + return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + _, intermediate_inputs = self.self_attention_with_norm_op( + x, + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + slot, + ) + + layer_output, metadata = self.mhc_mlp( + self.post_attention_norm_op, + self.mlp_op, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + deterministic=deterministic, + input_ids=decoder_input_tokens, + ) + load_balance_loss = metadata.get("load_balance_loss", None) + moe_bias_updates = metadata.get("moe_bias_updates", None) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache) + + +class DeepSeek4ScannableBlock(nnx.Module): + """A scannable block containing exactly two DeepSeek V4 layers (HCA and CSA). + + DeepSeek V4 layers alternate `compress_ratio=128` (HCA) and `compress_ratio=4` (CSA) + throughout the middle of the network. This block encapsulates one full `[128, 4]` + cycle so it can be perfectly scanned using JAX `nn.scan`. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | quantizations.AqtQuantization = None, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + # Layer 0 in the block: HCA (compress_ratio=128) with Standard MoE (is_hash_routing=False) + self.layers_0 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=128, + is_hash_routing=False, + ) + + # Layer 1 in the block: CSA (compress_ratio=4) with Standard MoE (is_hash_routing=False) + self.layers_1 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=4, + is_hash_routing=False, + ) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=None, + previous_chunk=None, + attention_metadata=None, + kv_cache=None, + ): + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + y = inputs + + y, _ = self.layers_0( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + y, _ = self.layers_1( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + return y, None + + +DeepSeek4LayerToLinen = nnx_wrappers.to_linen_class( + DeepSeek4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + +DeepSeek4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeek4ScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index e3b3aadf2d..48caa91ef1 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -75,6 +75,7 @@ "deepseek2-16b": "deepseek-ai/DeepSeek-V2-Lite", "deepseek3-671b": "deepseek-ai/DeepSeek-V3", "deepseek3.2-671b": "deepseek-ai/DeepSeek-V3.2", + "deepseek4": "deepseek-ai/DeepSeek-V4-Flash", "gpt-oss-20b": "openai/gpt-oss-20b", "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 1da95a184e..0b75aa9ff4 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -57,13 +57,13 @@ # Tests # ============================================================================== -# HuggingFace reference: https://huggingface.co/deepseek-ai/DeepSeek-V4/blob/main/modeling_deepseek_v4.py +# HuggingFace reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py # pylint: disable=line-too-long from jax.experimental import mesh_utils from jax.sharding import Mesh from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.configs import pyconfig from maxtext.layers.attention_compressed import CompressedAttention -from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding as MTRope + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4Attention from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4RotaryEmbedding as PTRope @@ -75,7 +75,7 @@ class DeepSeekV4RotaryEmbeddingTest(unittest.TestCase): def setUp(self): self.batch_size = 2 - self.seq_len = 16 + self.seq_len = 4096 self.head_dim = 128 self.num_heads = 4 self.main_rope_theta = 10000.0 @@ -408,6 +408,8 @@ def setUp(self): self.q_lora_rank = 32 self.o_groups = 2 self.o_lora_rank = 64 + self.qk_rope_head_dim = 64 + self.partial_rotary_factor = self.qk_rope_head_dim / self.head_dim self.rngs = nnx.Rngs(0) @@ -431,8 +433,12 @@ def setUp(self): layer_types=["sliding_attention"], num_hidden_layers=1, rope_parameters={ - "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": 1.0}, - "compress": {"rope_type": "default", "rope_theta": 160000.0, "partial_rotary_factor": 1.0}, + "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": self.partial_rotary_factor}, + "compress": { + "rope_type": "default", + "rope_theta": 160000.0, + "partial_rotary_factor": self.partial_rotary_factor, + }, }, sliding_window=2048, attention_dropout=0.0, @@ -524,9 +530,13 @@ def _run_e2e_test(self, layer_type, is_packed=False): "compressed_sparse_attention": self.pt_config.compress_rates["compressed_sparse_attention"], "heavily_compressed_attention": self.pt_config.compress_rates["heavily_compressed_attention"], } + compress_ratio = compress_ratio_map[layer_type] + layer_attention_type = AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType.COMPRESSED + mt_attn = CompressedAttention( config=mt_config, - compress_ratio=compress_ratio_map[layer_type], + compress_ratio=compress_ratio, + attention_type=layer_attention_type, num_query_heads=self.num_heads, num_kv_heads=1, head_dim=self.head_dim, @@ -540,14 +550,6 @@ def _run_e2e_test(self, layer_type, is_packed=False): rngs=self.rngs, ) self.mt_attn = mt_attn - if layer_type == "sliding_attention": - rope_factor = self.pt_config.rope_parameters["main"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=10000.0) - else: - rope_factor = self.pt_config.rope_parameters["compress"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=160000.0) - - mt_attn.rotary_embedding = mt_rope # 3. Copy Weights self._copy_linear(mt_attn.wq_a, ref_attn.q_a_proj) @@ -652,8 +654,7 @@ def _run_e2e_test(self, layer_type, is_packed=False): print(f"top_k_indices mismatches: {num_mismatches}") # 6. Execute MaxText - - mt_out = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) + mt_out, _ = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) # 7. Asserts if not is_packed: @@ -771,7 +772,7 @@ def setUp(self): "vocab_size": self.vocab_size, "first_num_hash_layers": 3, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, @@ -809,7 +810,7 @@ def test_hash_router(self): ) # Sync weights - mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy()) + mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy(), dtype=jnp.float32) mx_moe.gate.kernel.value = jnp.array(pt_router.weight.detach().numpy()).T hidden_states = torch.randn(self.batch_size, self.seq_len, self.hidden_dim) @@ -910,7 +911,7 @@ def test_swiglu_clamp(self): "topk_routing_group": 1, "mlp_activations_limit": limit, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 1975ad1abf..41557c8c3c 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -804,6 +804,26 @@ def test_deepseek32(self): ) ) + def test_deepseek4(self): + # test deepseek4 compile + compiled_trainstep_file = "/tmp/test_deepseek4.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=deepseek4-284b", + "per_device_batch_size=1", + "max_target_length=1024", + "attention=dot_product", + "dtype=bfloat16", + "weight_dtype=bfloat16", + ) + ) + @pytest.mark.cpu_only def test_indexer_dense_warmup(self): # test deepseek3.2 with sparse attention From cd0a54e0431664f8fafea13882ee169d949f3735 Mon Sep 17 00:00:00 2001 From: Dipak Gaikwad Date: Wed, 17 Jun 2026 20:00:50 +0000 Subject: [PATCH 2/2] Enabled auxilillary loss free load balancing and sequence wise load balancing for Deepseek. Tested by running training loop with new tiny Deeepseek V4 model added as part of the commit, here are the logs for testing Without load balancing active logs : https://paste.googleplex.com/6421399878107136 with load balancing logs : https://paste.googleplex.com/6551357300539392 Here are the results actived for reducing the varience : 1 === DeepSeek V4 Load Balancing Variance Analysis (Step 0 vs Step 20) === 2 3 | Layer Index | Routing Type | Step 0 Var (Baseline) | Step 20 Var (Run A) | Step 20 Var (Run B) | Improvement (A vs B) | 4 |-------------|--------------|-----------------------|---------------------|---------------------|----------------------| 5 | 0 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 6 | 1 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 7 | 2 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 8 | 3 | Top-K Routed | 7409.38 | 7509.25 | 3672.12 | 51.10% | 9 | 4 | Top-K Routed | 3158.38 | 3230.12 | 1216.00 | 62.35% | 10 | 5 | Top-K Routed | 5713.38 | 5772.75 | 2359.38 | 59.13% | 11 | 6 | Top-K Routed | 8295.25 | 8082.50 | 3674.12 | 54.54% | 12 | 7 | Top-K Routed | 4765.62 | 4614.62 | 1212.75 | 73.72% | 13 | 8 | Top-K Routed | 4960.75 | 4923.12 | 1663.50 | 66.21% | 14 | 9 | Top-K Routed | 3905.50 | 3816.25 | 1316.88 | 65.49% | 15 | 10 | Top-K Routed | 5057.00 | 4981.12 | 2257.75 | 54.67% | 16 | 11 | Top-K Routed | 10446.62 | 10381.62 | 5565.75 | 46.39% | 17 | 12 | Top-K Routed | 9538.50 | 9529.25 | 5319.12 | 44.18% | 18 | 13 | Top-K Routed | 7031.38 | 7131.25 | 3270.25 | 54.14% | 19 | 14 | Top-K Routed | 4852.00 | 4900.12 | 1906.88 | 61.09% | 20 | 15 | Top-K Routed | 9306.12 | 9342.88 | 4733.75 | 49.33% | 21 | 16 | Top-K Routed | 5811.25 | 5749.50 | 2110.88 | 63.29% | 22 | 17 | Top-K Routed | 6715.62 | 6874.25 | 2664.12 | 61.24% | 23 | 18 | Top-K Routed | 8145.50 | 7869.25 | 3383.75 | 57.00% | 24 | 19 | Top-K Routed | 6042.12 | 5908.62 | 2353.00 | 60.18% | 25 | 20 | Top-K Routed | 8559.88 | 8158.25 | 4333.38 | 46.88% | 26 | 21 | Top-K Routed | 11742.25 | 11943.62 | 7563.50 | 36.67% | 27 | 22 | Top-K Routed | 4959.62 | 5014.88 | 1998.62 | 60.15% | 28 | 23 | Top-K Routed | 7717.12 | 7751.88 | 3879.88 | 49.95% | 29 | 24 | Top-K Routed | 9017.75 | 9307.88 | 4702.75 | 49.48% | 30 | 25 | Top-K Routed | 14127.12 | 14111.25 | 8079.25 | 42.75% | 31 | 26 | Top-K Routed | 5074.25 | 5194.12 | 1675.50 | 67.74% | 32 | 27 | Top-K Routed | 11919.50 | 11204.38 | 6470.75 | 42.25% | 33 | 28 | Top-K Routed | 12241.75 | 12998.62 | 7624.12 | 41.35% | 34 | 29 | Top-K Routed | 9384.50 | 9005.00 | 5052.00 | 43.90% | 35 | 30 | Top-K Routed | 9698.62 | 9678.25 | 5231.75 | 45.94% | 36 | 31 | Top-K Routed | 12244.25 | 12392.75 | 7249.25 | 41.50% | 37 | 32 | Top-K Routed | 10030.00 | 9972.62 | 4755.50 | 52.31% | 38 | 33 | Top-K Routed | 7265.00 | 6973.62 | 3271.75 | 53.08% | 39 | 34 | Top-K Routed | 11945.50 | 11940.62 | 6076.88 | 49.11% | 40 | 35 | Top-K Routed | 12917.50 | 13740.00 | 7210.62 | 47.52% | 41 | 36 | Top-K Routed | 15011.62 | 15083.00 | 8870.62 | 41.19% | 42 | 37 | Top-K Routed | 10294.12 | 10176.25 | 5907.50 | 41.95% | 43 | 38 | Top-K Routed | 8928.62 | 9236.00 | 5136.62 | 44.38% | 44 | 39 | Top-K Routed | 15633.62 | 15171.00 | 9684.75 | 36.16% | 45 | 40 | Top-K Routed | 7687.75 | 7658.12 | 4521.25 | 40.96% | 46 | 41 | Top-K Routed | 12485.12 | 12270.38 | 6933.25 | 43.50% | 47 | 42 | Top-K Routed | 17641.25 | 17163.50 | 10974.12 | 36.06% | 48 |-------------|--------------|-----------------------|---------------------|---------------------|----------------------| 49 | TOTAL/AVG | Top-K Only | 357681.12 | 356762.50 | 185883.62 | 47.90% | Raw data collected for this analysis: https://paste.googleplex.com/5060754624610304 https://paste.googleplex.com/5473518849490944 --- src/maxtext/common/metric_logger.py | 2 +- src/maxtext/configs/models/deepseek4-284b.yml | 4 + src/maxtext/configs/models/deepseek4-tiny.yml | 69 +++++++++++ src/maxtext/configs/types.py | 3 +- src/maxtext/layers/moe.py | 6 +- src/maxtext/layers/quantizations.py | 2 +- src/maxtext/optimizers/optimizers.py | 15 +++ src/maxtext/trainers/pre_train/train.py | 79 +++++++++--- tests/unit/deepseek_routed_bias_test.py | 112 ++++++++++++++++++ tests/unit/metric_logger_test_coverage.py | 29 +++++ tests/unit/optimizers_test.py | 39 ++++++ tests/unit/train_nnx_test.py | 18 ++- 12 files changed, 354 insertions(+), 24 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-tiny.yml create mode 100644 tests/unit/deepseek_routed_bias_test.py create mode 100644 tests/unit/metric_logger_test_coverage.py diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 44771ecb05..2f1a564c6d 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step): if self.config.num_experts > 1: moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0) - log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}") + log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}") if self.config.mtp_num_layers > 0: mtp_loss = scalars.get("learning/mtp_loss", 0.0) diff --git a/src/maxtext/configs/models/deepseek4-284b.yml b/src/maxtext/configs/models/deepseek4-284b.yml index 5ba2dd062f..5cded0be59 100644 --- a/src/maxtext/configs/models/deepseek4-284b.yml +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -48,6 +48,10 @@ num_experts_per_tok: 6 mlp_activations_limit: 10 shared_experts: 1 routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] # --- Attention configuration --- attention_type: 'compressed' diff --git a/src/maxtext/configs/models/deepseek4-tiny.yml b/src/maxtext/configs/models/deepseek4-tiny.yml new file mode 100644 index 0000000000..881043777b --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-tiny.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny model config for DeepSeek V4 for CPU execution and testing + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +vocab_size: 129280 +head_dim: 32 +qk_rope_head_dim: 32 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 32 +indexer_n_heads: 4 +indexer_topk: 16 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 16 +num_experts_per_tok: 4 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 16 +o_groups: 4 +o_lora_rank: 16 +sliding_window_size: 32 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 4096 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index d1f293aae8..0cbf5d751a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -228,6 +228,7 @@ class ProfilerType(str, Enum): "deepseek3-tiny", "deepseek3.2-671b", "deepseek4-284b", + "deepseek4-tiny", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -2988,7 +2989,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") - if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK: + if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4): raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.") if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts: raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 4bb7cc7c08..61b4b6db75 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -348,8 +348,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. pre_bias_logits = output if self.use_bias: + # Architectural Note: Bias is an nnx.Param rather than nnx.Variable due to Linen/NNX state + # management transitions otherwise we will have to manage the overhead. We use jax.lax.stop_gradient + # here to mathematically enforce the Auxiliary-Loss-Free constraint, isolating it from sequence-wise loss leaks. bias = jnp.asarray(self.bias[...], self.dtype) - output += bias + output += jax.lax.stop_gradient(bias) return output, pre_bias_logits @@ -2163,7 +2166,6 @@ def dense_matmul( lb_loss = ( self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None ) - # TODO(dipakg-lang, b/521990776): Add sequence-wise balance loss * 0.0001 else: lb_loss = None diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 95bd79eb9f..86b61c7480 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -38,7 +38,7 @@ import qwix from qwix._src.core import dot_general_qt from qwix._src.core import sparsity -from qwix._src.utils import flax_util +from qwix._src import flax_util import qwix.pallas as qpl # Params used to define mixed precision quantization configs diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 9992d7674f..4200504927 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -238,6 +238,21 @@ def get_optimizer(config, learning_rate_schedule, model=None): lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)), ) + if getattr(config, "routed_bias", False): + import re + from flax import traverse_util + bias_regex = re.compile(".*gate.*bias.*") + # Architectural Note: Optax's Muon implementation correctly routes 2D+ matrices to the + # Newton-Schulz algorithm, but its fallback logic for 1D vectors (like our GateLogit bias) + # routes them to a standard AdamW optimizer *without* exposing a weight decay mask. + # To prevent the Muon optimizer from decaying our auxiliary-loss-free bias to zero, + # we apply a global optax.set_to_zero() mask here. + def bias_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + mask = {k: bool(bias_regex.match("/".join(map(str, k)))) for k in flat_params.keys()} + return traverse_util.unflatten_dict(mask) + base_opt = optax.chain(base_opt, optax.masked(optax.set_to_zero(), bias_mask_fn)) + return base_opt diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index fd2cc7b56c..c0a0be3a04 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,7 +36,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding -from flax import linen as nn, nnx +from flax import linen as nn, nnx, traverse_util from flax.linen import partitioning as nn_partitioning from flax.nnx import variablelib @@ -278,12 +278,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr else: max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") - # get MoE routed bias term updates - moe_bias_updates = None - if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) - # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. intermediate_outputs["logits"] = logits @@ -295,7 +289,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "indexer_loss": indexer_loss, - "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } @@ -421,9 +414,9 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): moe_lb_loss = aux["moe_lb_loss"] indexer_loss = aux.get("indexer_loss", 0.0) z_loss = aux.get("z_loss", 0.0) - moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) new_opt_state = None + bias_metrics = {} if isinstance(model, nn.Module): if config.gradient_clipping_threshold > 0: @@ -480,12 +473,39 @@ def move(path, value): else: new_state = state.apply_gradients(grads=full_grads) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family. + # We dynamically traverse the PyTree to apply updates because the topology varies drastically: + # 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers. + # 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely). + # 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`. + # Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths. + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + flat_params = traverse_util.flatten_dict(new_state.params) + new_flat_params = dict(flat_params) + + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for param_path in flat_params: + param_prefix = param_path[1:] if param_path[0] == "params" else param_path + if len(param_prefix) >= len(prefix) and param_prefix[:len(prefix)] == prefix and param_path[-2:] == ("gate", "bias"): + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, param_path)) + + old_val = new_flat_params[param_path].value if hasattr(new_flat_params[param_path], "value") else new_flat_params[param_path] + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(old_val) + + new_val = old_val + jnp.array(update_val).transpose() + if hasattr(new_flat_params[param_path], "value"): + new_flat_params[param_path] = new_flat_params[param_path].replace(value=new_val) + else: + new_flat_params[param_path] = new_val + + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) + + new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params)) else: if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) @@ -506,9 +526,31 @@ def move(path, value): new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias - target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + print("DEBUG_FLAT_INTERMEDIATES:", list(flat_intermediates.keys())) + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + target = new_state.model + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for key in prefix: + if hasattr(target, key): + target = getattr(target, key) + elif isinstance(target, dict) and key in target: + target = target[key] + else: + target = None + break + if target is None: + continue + for _, node in nnx.iter_graph(target): + if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None: + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, prefix)) + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(node.bias.value) + node.bias.value = node.bias.value + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -521,6 +563,7 @@ def move(path, value): "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } + scalar_metrics.update(bias_metrics) if config.use_qk_clip: if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) diff --git a/tests/unit/deepseek_routed_bias_test.py b/tests/unit/deepseek_routed_bias_test.py new file mode 100644 index 0000000000..9237f89f8b --- /dev/null +++ b/tests/unit/deepseek_routed_bias_test.py @@ -0,0 +1,112 @@ +import unittest +import jax +import jax.numpy as jnp +import optax +from flax.training import train_state +from maxtext.configs import pyconfig +from maxtext.models import models +from maxtext.trainers.pre_train import train as pre_train +class DeepSeekRoutedBiasTest(unittest.TestCase): + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.devices(), ('data',)) + def _make_dummy_data(self, batch=1, seq=16): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + def _create_and_run_train_step(self, config_args): + config = pyconfig.initialize(config_args) + rngs = jax.nnx.Rngs(0) if hasattr(jax, 'nnx') else __import__('flax.nnx', fromlist=['Rngs']).Rngs(0) + import flax.nnx as nnx + from maxtext.common import train_state_nnx + rngs = nnx.Rngs(0) + model = models.Transformer(config, self.mesh, quant=None, rngs=rngs) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + state_graphdef, state_pure = nnx.split(ts) + new_state, metrics = pre_train.train_step( + state_graphdef, config, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + return new_state, metrics + def test_deepseek_v3_dense_routed_bias_success(self): + """Proves that a DeepSeek V3 model with dense layers (no moe_layers attribute) + successfully traverses the state tree and updates routed bias without crashing. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=1", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertIn("learning/loss", metrics["scalar"]) + + def _create_and_run_linen_train_step(self, config_args): + config = pyconfig.initialize(config_args) + model = models.transformer_as_linen(config, self.mesh, quant=None) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + rng = jax.random.PRNGKey(0) + variables = model.init(rng, data["inputs"], data["inputs_position"], data["inputs_segmentation"]) + ts = train_state.TrainState.create( + apply_fn=model.apply, + params=variables["params"], + tx=optax.sgd(0.01) + ) + new_state, metrics = pre_train.train_step( + model, config, state_mesh_shardings=None, params_shardings=None, state=ts, data=data, dropout_rng=jax.random.PRNGKey(0) + ) + return new_state, metrics + + def test_deepseek_v3_moe_routed_bias_linen(self): + """Proves that a DeepSeek V3 model with MoE layers successfully traverses the + Linen state tree and updates routed bias. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=0", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_linen_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertTrue(any(key.startswith("learning/moe_bias_before_norm") for key in metrics["scalar"])) + self.assertTrue(any(key.startswith("learning/moe_bias_update_norm") for key in metrics["scalar"])) +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/metric_logger_test_coverage.py b/tests/unit/metric_logger_test_coverage.py new file mode 100644 index 0000000000..5567a7492c --- /dev/null +++ b/tests/unit/metric_logger_test_coverage.py @@ -0,0 +1,29 @@ +import unittest +from maxtext.common.metric_logger import MetricLogger +from maxtext.configs import pyconfig +from unittest import mock + +class MetricLoggerTest(unittest.TestCase): + def test_log_train_metrics_moe_lb_loss(self): + config = pyconfig.initialize(["", "src/maxtext/configs/base.yml", "run_name=test_run", "base_output_directory=/tmp/maxtext_output", "num_experts=2", "mtp_num_layers=0", "base_moe_mlp_dim=64", "base_mlp_dim=64"]) + + logger = MetricLogger(config, None) + metrics = { + "scalar": { + "learning/loss": 1.0, + "learning/lm_loss": 1.0, + "learning/total_weights": 1000, + "learning/moe_lb_loss": 0.000403, + "perf/step_time_seconds": 1.0, + "perf/per_device_tflops_per_sec": 1.0, + "perf/per_device_tokens_per_sec": 1.0, + } + } + with mock.patch("maxtext.common.metric_logger.max_logging.log") as mock_log: + logger._log_training_metrics(metrics, 1) + mock_log.assert_called() + called_args = mock_log.call_args[0][0] + self.assertIn("moe_lb_loss: 0.000403", called_args) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index b8eab1061e..4b9fe305eb 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -622,5 +622,44 @@ def __init__(self, rngs: nnx.Rngs): self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) +class TestGetOptimizerGlobalMask(unittest.TestCase): + """Tests that the global optimizer cleanly masks out the routed bias.""" + def test_routed_bias_global_mask(self): + config = pyconfig.initialize(["", "src/maxtext/configs/base.yml", "routed_bias=True", "opt_type=sgd"]) + # We define a dummy params dict containing a routed bias and a regular weight. + # The routed bias must be completely ignored by the optimizer. + params = { + "decoder": { + "moe_layers": { + "MoeBlock_0": { + "gate": { + "bias": jnp.array([1.0]), + "kernel": jnp.array([1.0]) + } + } + } + } + } + grads = { + "decoder": { + "moe_layers": { + "MoeBlock_0": { + "gate": { + "bias": jnp.array([0.5]), + "kernel": jnp.array([0.5]) + } + } + } + } + } + # We use sgd because it's simple to test updates, but the mask logic applies + # cleanly to any base optimizer returned by get_optimizer. + opt = optimizers.get_optimizer(config, learning_rate_schedule=0.1) + opt_state = opt.init(params) + updates, _ = opt.update(grads, opt_state, params) + # The routed bias update should be exactly 0.0 (masked by set_to_zero) + self.assertEqual(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["bias"].item(), 0.0) + # The kernel should receive the SGD gradient update (-0.1 * 0.5) + self.assertTrue(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["kernel"].item() < 0.0) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index ebeededbd7..b31bc4a5dc 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -61,8 +61,12 @@ class _Cfg: shard_mode: int = 0 # ShardMode.AUTO weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 + decoder_block: str = "default" +class _DummyDecoder(nnx.Module): + pass + class _TinyDecoder(nnx.Module): """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. @@ -73,6 +77,7 @@ class _TinyDecoder(nnx.Module): def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + self.decoder = _DummyDecoder() def __call__( self, @@ -125,7 +130,6 @@ def test_returns_loss_and_full_aux_dict(self): "total_weights", "moe_lb_loss", "indexer_loss", - "moe_bias_updates", "mtp_loss", ): self.assertIn(key, aux) @@ -194,6 +198,18 @@ def test_train_step_with_gradient_clipping(self): self.assertIsInstance(new_state, nnx.State) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + def test_train_step_deepseek_aux_loss(self): + cfg, ts = _build_state() + cfg.routed_bias = True + cfg.routed_bias_update_rate = 0.001 + cfg.decoder_block = "deepseek" + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + # The robust trainer logic will correctly traverse and NOT crash, ignoring the hardcoded path + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) class TestEvalStepNNX(unittest.TestCase): """Cover the NNX branch of eval_step (lines 568-570)."""