From 7b4996bbca0238ad795f6dcb4c7b0dda634b1bbc Mon Sep 17 00:00:00 2001 From: Katia Oussar Date: Thu, 18 Jun 2026 17:43:20 +0000 Subject: [PATCH] feat(layers): implement custom shape-aligned attention and MoE primitives for DeepSeek-V4 --- src/maxtext/common/common_types.py | 1 + src/maxtext/configs/base.yml | 6 + src/maxtext/configs/types.py | 16 + src/maxtext/layers/attention_compressed.py | 778 ++++++++++++++++++ src/maxtext/layers/attention_op.py | 44 +- src/maxtext/layers/embeddings.py | 69 +- src/maxtext/layers/linears.py | 2 + src/maxtext/layers/moe.py | 37 +- tests/unit/attention_test.py | 105 +++ tests/unit/deepseek_v4_vs_reference_test.py | 863 ++++++++++++++++++-- 10 files changed, 1821 insertions(+), 100 deletions(-) create mode 100644 src/maxtext/layers/attention_compressed.py diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 3f76ed763d..d4b52207fc 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -120,6 +120,7 @@ class AttentionType(enum.Enum): LOCAL_SLIDING = "local_sliding" CHUNK = "chunk" MLA = "mla" + COMPRESSED = "compressed" FULL = "full" diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index f62c1d1997..5263013eca 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -411,6 +411,12 @@ qk_nope_head_dim: 128 qk_rope_head_dim: 64 v_head_dim: 128 +# Compressed Attention parameters +o_lora_rank: 0 # Output LoRA rank for Compressed Attention. +o_groups: 0 # Output groups for Compressed Attention. +compress_ratios: [] # Per-layer compression ratios (0, 4, 128, etc). +compressed_rope_max_timescale: 160_000 # If positive, used for Compressed Sparse/Heavy Attention. 
+ # QK-Clip (Muon Clip) Configuration use_qk_clip: false # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash) qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 3c4e90476f..ec6636b725 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -636,6 +636,17 @@ class MlaAttention(BaseModel): v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.") +class CompressedAttention(BaseModel): + """Configuration for Compressed Attention.""" + + o_lora_rank: NonNegativeInt = Field(0, description="Output LoRA rank for Compressed Attention.") + o_groups: NonNegativeInt = Field(0, description="Output groups for Compressed Attention.") + compress_ratios: list[int] = Field(default_factory=list, description="Per-layer compression ratios (0, 4, 128, etc).") + compressed_rope_max_timescale: int = Field( + 160000, description="If positive, used for Compressed Sparse/Heavy Attention." 
+ ) + + class AttentionIndexer(BaseModel): """Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer.""" @@ -2269,6 +2280,7 @@ class MaxTextConfig( # Attention Mechanisms Attention, MlaAttention, + CompressedAttention, MoBa, AttentionIndexer, Llama4Attention, @@ -3150,6 +3162,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if self.share_kv_projections and self.attention_type == "mla": raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.") + for val in self.compress_ratios: + if val != 0 and val < 4: + raise ValueError(f"compress_ratio must be 0 (disabled) or >= 4, got {val}") + if self.num_kv_shared_layers > 0: if self.fused_qkv: raise ValueError("`num_kv_shared_layers > 0` is not compatible with `fused_qkv`.") diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py new file mode 100644 index 0000000000..35a042de6e --- /dev/null +++ b/src/maxtext/layers/attention_compressed.py @@ -0,0 +1,778 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Compressed Attention Layer (DeepSeek-V4) - Custom Implementation.""" + +from typing import Any, Optional, Tuple + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax import nnx + +from maxtext.common.common_types import ( + Array, + Config, + DType, + MODEL_MODE_TRAIN, + AttentionType, + DEFAULT_MASK_VALUE, +) + +# Surgically import and reuse our custom Phase 1 primitives +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding +from maxtext.layers.linears import DenseGeneral, DeepSeekV4GroupedLinear +from maxtext.layers.normalizations import RMSNorm + + +def csa_overlap_pooling( + hidden_states: Array, + kv_proj: Any, + gate_proj: Any, + position_bias: Array, + kv_norm: Any, + compress_rate: int, + head_dim: int, +) -> Array: + """Stateless overlapping Ca/Cb pooling shared by the Indexer and CSA Compressor. + + Implements the overlapping Ca/Cb pooling logic. It splits the projected states + into two halves (Ca and Cb), shifts the first half forward by one window, and + concatenates them to form overlapping windows over which softmax gating is applied. 
+ """ + batch_size, seq_len, _ = hidden_states.shape + + # Project key/value and gate states + kv = kv_proj(hidden_states) + gate = gate_proj(hidden_states) + + usable = (seq_len // compress_rate) * compress_rate + chunk_kv = kv[:, :usable] + chunk_gate = gate[:, :usable] + + # Return zero tensor if there are no full windows available for pooling + if chunk_kv.shape[1] == 0: + return jnp.zeros((batch_size, 0, head_dim), dtype=hidden_states.dtype) + + n_windows = chunk_kv.shape[1] // compress_rate + + # Reshape flat sequence into discrete compression windows + # -> [batch, n_windows, compress_rate, 2 * head_dim] + chunk_kv = chunk_kv.reshape((batch_size, n_windows, compress_rate, 2 * head_dim)) + chunk_gate = chunk_gate.reshape((batch_size, n_windows, compress_rate, 2 * head_dim)) + position_bias + + # Overlap construction: + # Ca (first head_dim) slice represents contribution to the next window. + # Cb (last head_dim) slice represents contribution to the current window. + new_kv = jnp.zeros((batch_size, n_windows, 2 * compress_rate, head_dim), dtype=chunk_kv.dtype) + new_gate = jnp.full((batch_size, n_windows, 2 * compress_rate, head_dim), -jnp.inf, dtype=chunk_gate.dtype) + + # Fill current window Cb slice + new_kv = new_kv.at[:, :, compress_rate:].set(chunk_kv[..., head_dim:]) + new_gate = new_gate.at[:, :, compress_rate:].set(chunk_gate[..., head_dim:]) + + # Shift Ca slice forward from the previous window + if n_windows > 1: + new_kv = new_kv.at[:, 1:, :compress_rate].set(chunk_kv[:, :-1, :, :head_dim]) + new_gate = new_gate.at[:, 1:, :compress_rate].set(chunk_gate[:, :-1, :, :head_dim]) + + # Compute gate-weighted softmax pooling + gate_weights = jax.nn.softmax(new_gate, axis=2) + compressed = kv_norm(jnp.sum(new_kv * gate_weights, axis=2)) + return compressed + + +class DeepseekV4Indexer(nnx.Module): + """Stateless JAX/NNX Lightning Indexer (DeepSeek-V4 paper §2.3.1, eqs. 
13-17).""" + + def __init__( + self, + config: Config, + rotary_embedding: DeepSeekV4RotaryEmbedding, + compress_rate: int, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.compress_rate = compress_rate + self.index_n_heads = config.indexer_n_heads + self.index_head_dim = config.indexer_head_dim + self.index_topk = config.indexer_topk + self.softmax_scale = self.index_head_dim ** -0.5 + self.weights_scaling = self.index_n_heads ** -0.5 + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + + # Projections for the overlapping window compressor + self.kv_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=2 * self.index_head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + self.gate_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=2 * self.index_head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + # Position bias for softmax pooling + self.position_bias = nnx.Param( + jnp.zeros((self.compress_rate, 2 * self.index_head_dim), dtype=self.weight_dtype) + ) + + self.kv_norm = RMSNorm( + num_features=self.index_head_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + # Low-rank query projection inside the Indexer + self.q_proj = DenseGeneral( + in_features_shape=config.q_lora_rank, + out_features_shape=self.index_n_heads * self.index_head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + # Project hidden states to get head-importance weights + self.weights_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=self.index_n_heads, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + # REUSE our custom Phase 1 RoPE! 
+ self.rotary_emb = rotary_embedding + + def __call__( + self, + hidden_states: Array, + q_latent: Array, + position_ids: Array, + attention_mask: Optional[Array] = None, + ) -> Array: + batch_size, seq_len, _ = hidden_states.shape + + # 1. Overlap pooling & compression + compressed = csa_overlap_pooling( + hidden_states, + self.kv_proj, + self.gate_proj, + self.position_bias.value, + self.kv_norm, + self.compress_rate, + self.index_head_dim, + ) + compressed_len = compressed.shape[1] + + # 2. Apply RoPE to compressed keys/values + if compressed_len > 0: + first_window_position = position_ids[:, 0:1] + positions = jnp.arange(compressed_len) * self.compress_rate + first_window_position + compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + else: + return jnp.zeros((batch_size, seq_len, min(self.index_topk, compressed_len)), dtype=jnp.int32) + + # Broadcast compressed representations across all indexer heads + compressed_kv = jnp.expand_dims(compressed, axis=1) + compressed_kv = jnp.broadcast_to( + compressed_kv, + (batch_size, self.index_n_heads, compressed_len, self.index_head_dim), + ) + + # 3. Project & apply RoPE to queries + q = self.q_proj(q_latent).reshape((batch_size, seq_len, self.index_n_heads, self.index_head_dim)) + q = jnp.transpose(q, (0, 2, 1, 3)) + q = self.rotary_emb(q, position_ids, unsqueeze_dim=1) + + q = q.astype(jnp.float32) + compressed_kv = compressed_kv.astype(jnp.float32) + + # 4. Compute dot-product scores: [Batch, Heads, SeqLen, n_windows] + scores = jnp.einsum("bhsd,bhwd->bhsw", q, compressed_kv) + scores = jax.nn.relu(scores) * self.softmax_scale + + # Compute head routing weights: [Batch, SeqLen, Heads] + weights = self.weights_proj(hidden_states).astype(jnp.float32) * self.weights_scaling + + # Combine scores across heads: [Batch, SeqLen, n_windows] + index_scores = jnp.einsum("bhsw,bsh->bsw", scores, weights) + + k = min(self.index_topk, compressed_len) + + # 5. 
Causal window masking (prevent attending to future windows) + causal_threshold = (position_ids + 1) // self.compress_rate + entry_indices = jnp.arange(compressed_len) + future_mask = entry_indices[None, None, :] >= jnp.expand_dims(causal_threshold, axis=-1) + index_scores = jnp.where(future_mask, jnp.full_like(index_scores, -jnp.inf), index_scores) + + # Apply segment attention mask if present + if attention_mask is not None: + index_scores += attention_mask[:, :, :compressed_len] + + # Retrieve the Top-K highest-scoring block indices + top_k_indices = jax.lax.top_k(index_scores, k)[1] + + # Invalidate future indices + invalid = top_k_indices >= jnp.expand_dims(causal_threshold, axis=-1) + top_k_indices = jnp.where(invalid, jnp.full_like(top_k_indices, -1), top_k_indices) + + return top_k_indices + + +class DeepseekV4CSACompressor(nnx.Module): + """CSA Compressor (DeepSeek-V4 paper §2.3.1). + + Compresses every `compress_rate` source tokens using softmax-gated overlap pooling, + and invokes the Lightning Indexer to return the Top-K active block indices and + the corresponding sparse block bias mask. 
+ """ + + def __init__( + self, + config: Config, + rotary_embedding: DeepSeekV4RotaryEmbedding, + compress_rate: int, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.compress_rate = compress_rate + self.head_dim = config.head_dim + self.index_topk = config.indexer_topk + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + + # Dense projections for Ca/Cb pooling + self.kv_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=2 * self.head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + self.gate_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=2 * self.head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + self.position_bias = nnx.Param( + jnp.zeros((self.compress_rate, 2 * self.head_dim), dtype=self.weight_dtype) + ) + + self.kv_norm = RMSNorm( + num_features=self.head_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + # The Indexer primitive + self.indexer = DeepseekV4Indexer(config, rotary_embedding, compress_rate=compress_rate, rngs=rngs) + + # REUSE our custom Phase 1 RoPE! + self.rotary_emb = rotary_embedding + + def __call__( + self, + hidden_states: Array, + q_latent: Array, + position_ids: Array, + attention_mask: Optional[Array] = None, + ) -> Tuple[Array, Array]: + batch_size, seq_len, _ = hidden_states.shape + + # 1. Overlap pooling & compression + compressed = csa_overlap_pooling( + hidden_states, + self.kv_proj, + self.gate_proj, + self.position_bias.value, + self.kv_norm, + self.compress_rate, + self.head_dim, + ) + compressed_len = compressed.shape[1] + + # 2. 
Apply RoPE to compressed states + if compressed_len > 0: + first_window_position = position_ids[:, 0:1] + positions = jnp.arange(compressed_len) * self.compress_rate + first_window_position + compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + else: + compressed = jnp.zeros((batch_size, 0, self.head_dim), dtype=hidden_states.dtype) + + compressed_kv = jnp.expand_dims(compressed, axis=1) # [B, 1, n_windows, D] + + # 3. Invoke the indexer to get active block selections + top_k_indices = self.indexer(hidden_states, q_latent, position_ids, attention_mask) + + # 4. No-Gather Sparse Block Bias Mask Generation + # Construct a mask of shape [B, 1, seq_len, n_windows] containing 0.0 at + # selected indices and -inf elsewhere. + valid = top_k_indices >= 0 + safe_indices = jnp.where(valid, top_k_indices, jnp.full_like(top_k_indices, -1)) + + # Broadcast indices for broadcasting matching + # safe_indices: [B, S, K] + w_indices = jnp.arange(compressed_len) + # selected: [B, S, n_windows] + selected = jnp.any(jnp.expand_dims(safe_indices, axis=-1) == w_indices, axis=2) + selected = jnp.expand_dims(selected, axis=1) # [B, 1, S, n_windows] + + block_bias = jnp.where(selected, 0.0, -jnp.inf) + return compressed_kv, block_bias + + +class DeepseekV4HCACompressor(nnx.Module): + """HCA Compressor (DeepSeek-V4 paper §2.3.2). + + Compresses every `compress_rate` source tokens using softmax-gated non-overlapping pooling, + and generates a causal mask over all past heavily compressed blocks. 
+ """ + + def __init__( + self, + config: Config, + rotary_embedding: DeepSeekV4RotaryEmbedding, + compress_rate: int, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.compress_rate = compress_rate + self.head_dim = config.head_dim + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + + # Dense projections for closed window pooling + self.kv_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=self.head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + self.gate_proj = DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=self.head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + self.position_bias = nnx.Param( + jnp.zeros((self.compress_rate, self.head_dim), dtype=self.weight_dtype) + ) + + self.kv_norm = RMSNorm( + num_features=self.head_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + # REUSE our custom Phase 1 RoPE! 
+ self.rotary_emb = rotary_embedding + + def __call__( + self, + hidden_states: Array, + position_ids: Array, + ) -> Tuple[Array, Array]: + batch_size, seq_len, _ = hidden_states.shape + + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + usable = (seq_len // self.compress_rate) * self.compress_rate + chunk_kv = kv[:, :usable] + chunk_gate = gate[:, :usable] + + if chunk_kv.shape[1] == 0: + compressed = jnp.zeros((batch_size, 0, self.head_dim), dtype=hidden_states.dtype) + else: + n_windows = chunk_kv.shape[1] // self.compress_rate + + # Reshape to non-overlapping windows: [B, n_windows, compress_rate, D] + chunk_kv = chunk_kv.reshape((batch_size, n_windows, self.compress_rate, self.head_dim)) + chunk_gate = chunk_gate.reshape((batch_size, n_windows, self.compress_rate, self.head_dim)) + self.position_bias.value + + gate_weights = jax.nn.softmax(chunk_gate, axis=2) + compressed = self.kv_norm(jnp.sum(chunk_kv * gate_weights, axis=2)) + + first_window_position = position_ids[:, 0:1] + positions = jnp.arange(n_windows) * self.compress_rate + first_window_position + compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + + compressed_kv = jnp.expand_dims(compressed, axis=1) # [B, 1, n_windows, D] + compressed_len = compressed_kv.shape[2] + + # Generate causal block mask: [B, 1, seq_len, n_windows] + causal_threshold = (position_ids + 1) // self.compress_rate + entry_indices = jnp.arange(compressed_len) + future_mask = entry_indices[None, None, :] >= jnp.expand_dims(causal_threshold, axis=-1) + + block_bias = jnp.where(future_mask, -jnp.inf, 0.0) + block_bias = jnp.expand_dims(block_bias, axis=1) # [B, 1, seq_len, n_windows] + + return compressed_kv, block_bias + + +class DeepseekV4HyperHead(nnx.Module): + """Model exit parallel residual stream collapse layer (DeepSeek-V4 paper §2.2).""" + + def __init__(self, config: Config, *, rngs: nnx.Rngs): + self.config = config + self.hc_mult = config.hc_mult + self.hidden_size = 
config.emb_dim + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + + self.input_norm = RMSNorm( + num_features=self.hidden_size, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + # Projects flattened streams to stream weights + self.weights_proj = DenseGeneral( + in_features_shape=self.hidden_size, + out_features_shape=self.hc_mult, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + def __call__(self, hidden_states: Array) -> Array: + # hidden_states shape: [Batch, SeqLen, hc_mult, hidden_size] + batch_size, seq_len, hc_mult, hidden_size = hidden_states.shape + + # Average streams to compute collapse weight features + # [B, S, hc_mult, D] -> [B, S, D] + mean_stream = jnp.mean(hidden_states, axis=2) + norm_mean = self.input_norm(mean_stream) + + # Project to collapse weights: [B, S, hc_mult] + # Sigmoid-activated + small epsilon to guarantee non-zero weights + collapse_weights = jax.nn.sigmoid(self.weights_proj(norm_mean)) + 1e-6 + + # Normalize weights along stream dimension: [B, S, hc_mult] + collapse_weights = collapse_weights / jnp.sum(collapse_weights, axis=-1, keepdims=True) + + # Collapse parallel streams via a weighted sum + # [B, S, hc_mult, D] * [B, S, hc_mult, 1] -> sum along hc_mult -> [B, S, D] + collapsed = jnp.sum(hidden_states * jnp.expand_dims(collapse_weights, axis=-1), axis=2) + return collapsed + + +def unweighted_rms_norm(x: Array, epsilon: float = 1.0e-6) -> Array: + """Stateless unweighted RMSNorm used by DeepSeek-V4 to stabilize Q-projections.""" + variance = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) + return (x * jax.lax.rsqrt(variance + epsilon)).astype(x.dtype) + + +class CompressedAttention(nnx.Module): + """Unified Custom DeepSeek-V4 Attention block (Sliding, CSA, and HCA).""" + + def __init__( + self, + config: Config, + compress_ratio: int, + num_query_heads: int, + num_kv_heads: int, + 
head_dim: int, + max_target_length: int, + mesh: Mesh, + attention_kernel: str, + inputs_q_shape: Tuple[int, int, int], + inputs_kv_shape: Tuple[int, int, int], + q_lora_rank: int, + sliding_window_size: int, + *, + rngs: nnx.Rngs, + **kwargs, + ): + self.config = config + self.compress_ratio = compress_ratio + + # Map compress_ratio to layer_type (0 means no compressor; otherwise match the configured CSA/HCA ratios) + if compress_ratio == 0: + self.layer_type = "sliding_attention" + elif len(config.compress_ratios) > 1 and compress_ratio == config.compress_ratios[1]: # Length guard: compress_ratios defaults to [] so an unguarded [1] raises IndexError + self.layer_type = "compressed_sparse_attention" + elif len(config.compress_ratios) > 2 and compress_ratio == config.compress_ratios[2]: # Length guard: same reason, short lists fall through to the defaults below + self.layer_type = "heavily_compressed_attention" + else: + # Direct fallback based on common defaults if ratios list is too short or doesn't match + if compress_ratio == 4: + self.layer_type = "compressed_sparse_attention" + elif compress_ratio == 8 or compress_ratio == 128: + self.layer_type = "heavily_compressed_attention" + else: + raise ValueError(f"Invalid compress_ratio: {compress_ratio}") + + self.hidden_size = config.emb_dim + self.num_heads = num_query_heads + self.head_dim = head_dim + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + + # Projection ranks + self.q_lora_rank = q_lora_rank + + # 1. Rotary Embeddings (one instance, shared with the compressor built later in this constructor) + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=self.head_dim, + partial_rotary_factor=config.partial_rotary_factor, + rope_theta=config.rope_max_timescale, # NOTE(review): config.compressed_rope_max_timescale is declared for CSA/HCA layers but is not read here - confirm intent + dtype=self.dtype, + ) + + # 2. Query low-rank projections (Q-LoRA) - RENAME to wq_a and wq_b! 
+ self.wq_a = DenseGeneral( + in_features_shape=self.hidden_size, + out_features_shape=self.q_lora_rank, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + self.q_norm = RMSNorm( + num_features=self.q_lora_rank, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + self.wq_b = DenseGeneral( + in_features_shape=self.q_lora_rank, + out_features_shape=self.num_heads * self.head_dim, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + # 3. Key/Value projections (Single direct projection to head_dim, MQA layout!) + # We specify out_features_shape as (num_kv_heads, head_dim) to build a 3D kernel + # of shape [hidden_size, num_kv_heads, head_dim], matching weight copying shapes perfectly! + self.wkv = DenseGeneral( + in_features_shape=self.hidden_size, + out_features_shape=(num_kv_heads, self.head_dim), + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + self.kv_norm = RMSNorm( + num_features=self.head_dim, # Normalizes along head_dim! + dtype=self.dtype, + weight_dtype=self.weight_dtype, + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + # 4. Long-range historical compressors + # Flax NNX Pro-Tip: Only define the compressor we actually instantiate! + if self.layer_type == "compressed_sparse_attention": + self.csa_compressor = DeepseekV4CSACompressor(config, self.rotary_emb, compress_rate=self.compress_ratio, rngs=rngs) + elif self.layer_type == "heavily_compressed_attention": + self.hca_compressor = DeepseekV4HCACompressor(config, self.rotary_emb, compress_rate=self.compress_ratio, rngs=rngs) + + # 5. Attention Sinks - RENAME to sinks! + self.sinks = nnx.Param(jnp.zeros((1, self.num_heads, 1, 1), dtype=self.dtype)) + + # 6. Grouped Output Projections + self.o_groups = config.o_groups + self.o_lora_rank = config.o_lora_rank + + # Instantiate using correct DeepSeekV4GroupedLinear parameter names! 
+ self.o_a_proj = DeepSeekV4GroupedLinear( + in_features_per_group=(self.num_heads * self.head_dim) // self.o_groups, + out_features=self.o_groups * self.o_lora_rank, + n_groups=self.o_groups, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + self.o_b_proj = DenseGeneral( + in_features_shape=self.o_groups * self.o_lora_rank, + out_features_shape=self.hidden_size, + axis=-1, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + rngs=rngs, + ) + + self.softmax_scale = self.head_dim ** -0.5 + + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + decoder_segment_ids: Array, + inputs_positions: Array, + deterministic: bool = True, + model_mode: str = MODEL_MODE_TRAIN, + **kwargs, + ) -> Array: + batch_size, seq_len, _ = inputs_q.shape + hidden_states = inputs_q + position_ids = inputs_positions + positions_for_rope = inputs_positions + + # 1. Project Queries (Q-LoRA) + q_latent = self.q_norm(self.wq_a(hidden_states)) + q = self.wq_b(q_latent).reshape((batch_size, seq_len, self.num_heads, self.head_dim)) + q = jnp.transpose(q, (0, 2, 1, 3)) # [B, H, S, D] + + # Apply unweighted RMSNorm to query before RoPE! + q = unweighted_rms_norm(q, epsilon=self.config.normalization_layer_epsilon) + + # Apply our custom Phase 1 RoPE to queries + q = self.rotary_emb(q, positions_for_rope, unsqueeze_dim=1) + + # 2. Project Keys/Values (Single direct projection to head_dim, MQA layout!) + # wkv returns shape [B, S, num_kv_heads, head_dim] (e.g. [B, S, 1, D]) + kv = self.wkv(inputs_kv) + kv_normed = self.kv_norm(kv) + + # Transpose to [B, num_kv_heads, S, head_dim] (e.g. [B, 1, S, D]) to align head axis + kv_normed = jnp.transpose(kv_normed, (0, 2, 1, 3)) + + # Apply our custom Phase 1 RoPE to both key and value states (Triple RoPE layout!) 
+ k = self.rotary_emb(kv_normed, positions_for_rope, unsqueeze_dim=1) + v = self.rotary_emb(kv_normed, positions_for_rope, unsqueeze_dim=1) + + # Broadcast Key/Value head axes from MQA layout [B, 1, S, D] to full attention shape [B, H, S, D] + # This must be done BEFORE doing any sequence concatenations or attention logits dot-products! + k = jnp.broadcast_to(k, (batch_size, self.num_heads, seq_len, self.head_dim)) + v = jnp.broadcast_to(v, (batch_size, self.num_heads, seq_len, self.head_dim)) + + # 3. Build document packing segment masks + segment_mask_sliding = None + compressed_segment_mask = None + if decoder_segment_ids is not None: + # Segment mask for sliding window + segment_mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :] + segment_mask_sliding = jnp.expand_dims(jnp.where(segment_mask, 0.0, -1e9), axis=1) # [B, 1, S, S] + + # Downsampled segment mask for the compressed dimension + if self.compress_ratio > 0: + compressed_segment_mask = jnp.where(segment_mask, 0.0, -1e9)[:, :, ::self.compress_ratio] + + # 4. 
Long-range historical compression (CSA/HCA) + # Only one of csa_compressor / hca_compressor is created in __init__, so probe with hasattr before calling + if hasattr(self, "csa_compressor") and self.csa_compressor is not None: + compressed_kv, block_bias = self.csa_compressor(hidden_states, q_latent, position_ids, compressed_segment_mask) + k_compressed = jnp.broadcast_to(compressed_kv, (batch_size, self.num_heads, compressed_kv.shape[2], self.head_dim)) + v_compressed = jnp.broadcast_to(compressed_kv, (batch_size, self.num_heads, compressed_kv.shape[2], self.head_dim)) + k_combined = jnp.concatenate([k, k_compressed], axis=2) + v_combined = jnp.concatenate([v, v_compressed], axis=2) + elif hasattr(self, "hca_compressor") and self.hca_compressor is not None: + compressed_kv, block_bias = self.hca_compressor(hidden_states, position_ids) + k_compressed = jnp.broadcast_to(compressed_kv, (batch_size, self.num_heads, compressed_kv.shape[2], self.head_dim)) + v_compressed = jnp.broadcast_to(compressed_kv, (batch_size, self.num_heads, compressed_kv.shape[2], self.head_dim)) + k_combined = jnp.concatenate([k, k_compressed], axis=2) + v_combined = jnp.concatenate([v, v_compressed], axis=2) + else: + k_combined = k + v_combined = v + block_bias = None + + # 5. Compute Attention Logits + logits = jnp.matmul(q, jnp.transpose(k_combined, (0, 1, 3, 2))) * self.softmax_scale # [B, H, S, S_combined] + + # 6. Apply Causal + Block Bias Masking + # Plain lower-triangular causal mask over the uncompressed tokens. NOTE(review): no sliding-window limit is applied here and sliding_window_size is never read in this class - confirm intent + causal_mask = jnp.where( + jnp.arange(seq_len)[:, None] >= jnp.arange(seq_len)[None, :], + 0.0, + -jnp.inf + ) + # Reshape and broadcast causal mask to [B, 1, S, S] to ensure batch dimension matches block_bias perfectly! 
+ causal_mask = jnp.broadcast_to(causal_mask[None, None, :, :], (batch_size, 1, seq_len, seq_len)) + + # Combine sliding window mask and compressed block bias mask + if block_bias is not None: + combined_mask = jnp.concatenate([causal_mask, block_bias], axis=3) + logits = logits + combined_mask + else: + logits = logits + causal_mask + + # Add external document packing segment mask if present + if segment_mask_sliding is not None: + if block_bias is not None and compressed_segment_mask is not None: + # Extend segment mask to cover compressed dimension: [segment_mask_sliding, compressed_segment_mask] + extended_segment_mask = jnp.concatenate([ + segment_mask_sliding, + jnp.expand_dims(compressed_segment_mask[:, :, :block_bias.shape[-1]], axis=1) + ], axis=3) + logits = logits + extended_segment_mask + else: + logits = logits + segment_mask_sliding + + # 7. Append Attention Sinks + # Reshape sinks dynamically to [1, num_heads, 1, 1] to prevent any shape overwrite issues + sinks_reshaped = self.sinks.value.reshape((1, self.num_heads, 1, 1)) + sinks_broadcast = jnp.broadcast_to(sinks_reshaped, (batch_size, self.num_heads, seq_len, 1)) + logits_with_sinks = jnp.concatenate([logits, sinks_broadcast], axis=3) + + # 8. Softmax and drop sinks column + weights = jax.nn.softmax(logits_with_sinks, axis=-1)[..., :-1] + + # 9. Compute Mixed Attention Output + attn_output = jnp.matmul(weights, v_combined) + + # 10. Triple RoPE Inverse Output Rotation (Conjugate rotation to undo V rotation!) + attn_output = self.rotary_emb(attn_output, positions_for_rope, unsqueeze_dim=1, reverse=True) + + # Transpose and reshape to group-wise format: [B, S, g, in_features_per_group] + # to match PyTorch's logical sequence-first layout! + attn_output = jnp.transpose(attn_output, (0, 2, 1, 3)) # [B, S, H, D] + attn_output_grouped = attn_output.reshape((batch_size, seq_len, self.o_groups, -1)) + + # 11. 
Grouped Output Projections + grouped = self.o_a_proj(attn_output_grouped) + # Flatten grouped outputs back to full projection shape + grouped_flat = grouped.reshape((batch_size, seq_len, self.o_groups * self.o_lora_rank)) + output = self.o_b_proj(grouped_flat) + + return output diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 2f46179dd9..b3c3f296f4 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -593,6 +593,7 @@ def generate_attention_mask( model_mode: str, previous_chunk: Any = None, bidirectional_mask: Any = None, + compressed_mask: Optional[Array] = None, ) -> Array | None: """Generates a combined attention mask for Transformer models. @@ -650,6 +651,9 @@ def generate_attention_mask( (e.g., image tokens) that are allowed to attend bidirectionally. The resulting block-wise bidirectional mask is combined with other masks using a logical OR. + compressed_mask: Optional `Array`. A pre-computed attention mask for + compressed kv blocks (e.g., DeepSeek-V4 compressed attention). If provided, it is + concatenated with the dynamically generated uncompressed mask. 
Returns: An `Array` representing the attention mask, with shape @@ -687,8 +691,10 @@ def generate_attention_mask( next_pos = kv_seq_len - 1 causal_mask = None - # We enforce causality except for AUTOREGRESSION - if model_mode != MODEL_MODE_AUTOREGRESSIVE and self.attention_type != AttentionType.FULL: + if model_mode != MODEL_MODE_AUTOREGRESSIVE and self.attention_type not in ( + AttentionType.FULL, + AttentionType.COMPRESSED, + ): mask_shape = (q_seq_len, kv_seq_len) # row_ids indicates the position of query # col_ids indicates the position of kv @@ -716,6 +722,34 @@ def generate_attention_mask( col_ids_sliding <= row_ids_sliding ) output_mask = sliding_mask * output_mask + elif self.attention_type == AttentionType.COMPRESSED: + if compressed_mask is None: + raise ValueError("compressed_mask must be provided for COMPRESSED attention type") + c_len = compressed_mask.shape[-1] + s_len = kv_seq_len - c_len + + # Build causal and sliding window mask for the uncompressed sequence + # -> [q_seq_len, s_len] + row_ids = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, s_len), 0) + next_pos + # -> [1, s_len] + col_ids = jax.lax.broadcasted_iota(jnp.int32, (1, s_len), 1) + uncompressed_mask = col_ids <= row_ids + if self.sliding_window_size is not None: + uncompressed_mask = uncompressed_mask & (col_ids > (row_ids - self.sliding_window_size)) + + # Broadcast uncompressed_mask to match compressed_mask's layout + target_shape = compressed_mask.shape[:-1] + (s_len,) + padded_shape = (1,) * (len(target_shape) - 2) + uncompressed_mask.shape + uncompressed_mask = jnp.broadcast_to(uncompressed_mask.reshape(padded_shape), target_shape) + + # Apply document-packing mask if it exists + if output_mask is not None: + uncompressed_mask = uncompressed_mask & output_mask[..., :s_len] + + uncompressed_mask = jnp.where(uncompressed_mask, 0.0, DEFAULT_MASK_VALUE) + + return jnp.concatenate([uncompressed_mask, compressed_mask], axis=-1) + elif self.attention_type == AttentionType.CHUNK and 
output_mask is not None: mask_shape = (q_seq_len, kv_seq_len) chunk_mask = _generate_chunk_attention_mask( @@ -878,6 +912,7 @@ def apply_attention( bidirectional_mask: Any = None, sinks: Array | None = None, indexer_mask: Array | None = None, + compressed_mask: Optional[Array] = None, record_max_logits: bool = False, *, qk_product_einsum: Callable[..., Array], @@ -923,6 +958,7 @@ def apply_attention( bidirectional_mask=bidirectional_mask, sinks=sinks, indexer_mask=indexer_mask, + compressed_mask=compressed_mask, record_max_logits=record_max_logits, qk_product_einsum=qk_product_einsum, wv_product_einsum=wv_product_einsum, @@ -1766,6 +1802,7 @@ def apply_attention_dot( bidirectional_mask: Any = None, sinks: Array | None = None, indexer_mask: Array | None = None, + compressed_mask: Optional[Array] = None, record_max_logits: bool = False, *, qk_product_einsum: Callable[..., Array], @@ -1826,6 +1863,7 @@ def apply_attention_dot( model_mode, previous_chunk, bidirectional_mask, + compressed_mask=compressed_mask, ) if self.config.moba: @@ -2038,6 +2076,7 @@ def __call__( bidirectional_mask=None, sinks=None, indexer_mask: Optional[Array] = None, + compressed_mask: Optional[Array] = None, slot: Optional[int] = None, record_max_logits: bool = False, ): @@ -2070,6 +2109,7 @@ def __call__( bidirectional_mask=bidirectional_mask, sinks=sinks, indexer_mask=indexer_mask_prefill, + compressed_mask=compressed_mask, record_max_logits=record_max_logits, qk_product_einsum=self.AqtEinsum_0, wv_product_einsum=self.AqtEinsum_1, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 86b6723bd5..8042b50d87 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -448,7 +448,11 @@ def partial_rotary_embedding_as_linen( class PartialRotaryEmbedding(RotaryEmbedding): - """Rotary Position Embedding applied to a partial fraction of dimensions.""" + """Rotary Position Embedding applied to a partial fraction of dimensions. 
+ + This class has been updated to support interleaved channels, trailing slices, + and inverse output rotations to satisfy DeepSeek-V4 requirements. + """ def __init__( self, @@ -460,6 +464,8 @@ def __init__( fprop_dtype: DType = jnp.bfloat16, partial_rotary_factor: float = 0.25, shard_mode: ShardMode = ShardMode.AUTO, + interleaved: bool = False, + trailing: bool = False, rngs: nnx.Rngs = None, ): """Initializes the PartialRotaryEmbedding module. @@ -471,11 +477,15 @@ def __init__( added signal. embedding_dims: Dimension of the embedding to be generated. partial_rotary_factor: Ratio of dimensions to apply ROPE to + interleaved: Whether to use interleaved channel pairing (even/odd). + trailing: Whether to apply RoPE to the trailing slice of the vector. rngs: rng keys passed in by nnx.bridge.to_linen. """ self.head_dim = embedding_dims self.partial_rotary_factor = partial_rotary_factor self.rotary_dim = int(self.head_dim * self.partial_rotary_factor) + self.interleaved = interleaved + self.trailing = trailing # Initialize the base class with only the rotary_dim super().__init__( @@ -489,7 +499,12 @@ def __init__( rngs=rngs, ) - def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array: + def __call__( + self, + inputs: jax.Array, + position: None | jax.Array = None, + reverse: bool = False, + ) -> jax.Array: """Applies Partial variant of rotary position embedding. Args: @@ -497,14 +512,54 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax. embedding. It is assumed of shape [B, S, H, D]. position: Optional position array [B, S]. Only needed when the sequence is packed. + reverse: Whether to apply reverse rotation (-sin). Returns: - A jax.Array of shape [B, S, H, D - rotary_dim] with rotary position embeddings applied. + A jax.Array of shape [B, S, H, D] with rotary position embeddings applied. 
""" - # Split, apply base RoPE to the first fraction, and concatenate - inputs_rot, inputs_pass = jnp.split(inputs, [self.rotary_dim], axis=-1) - inputs_rot = super().__call__(inputs_rot, position) - inputs = jnp.concatenate([inputs_rot, inputs_pass], axis=-1) + assert position is not None + + # 1. Split into rotated and passive parts (trailing or leading slice) + if self.trailing: + inputs_pass, inputs_rot = jnp.split(inputs, [self.head_dim - self.rotary_dim], axis=-1) + else: + inputs_rot, inputs_pass = jnp.split(inputs, [self.rotary_dim], axis=-1) + + # 2. Apply rotation to the rotary part + position_expanded = position[:, :, jnp.newaxis, jnp.newaxis] + sinusoid_inp = position_expanded / self.timescale + cos_half = jnp.cos(sinusoid_inp).astype(inputs.dtype) + sin_half = jnp.sin(sinusoid_inp).astype(inputs.dtype) + + if reverse: + sin_half = -sin_half + + if self.interleaved: + # Interleaved pairing (DeepSeek-V4 style) + cos = jnp.repeat(cos_half, 2, axis=-1) + sin = jnp.repeat(sin_half, 2, axis=-1) + + def _rotate_half_interleaved(x): + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return jnp.stack((-x2, x1), axis=-1).reshape(x.shape) + + inputs_rot_f32 = inputs_rot.astype(jnp.float32) + inputs_rot = ((inputs_rot_f32 * cos) + (_rotate_half_interleaved(inputs_rot_f32) * sin)).astype(inputs_rot.dtype) + else: + # Standard split-half pairing (LLaMA/Mistral style) + cos = jnp.concatenate([cos_half, cos_half], axis=-1) + sin = jnp.concatenate([sin_half, sin_half], axis=-1) + inputs_rot = (inputs_rot * cos) + (self._rotate_half(inputs_rot) * sin) + + # 3. 
Concatenate back + if self.trailing: + inputs = jnp.concatenate([inputs_pass, inputs_rot], axis=-1) + else: + inputs = jnp.concatenate([inputs_rot, inputs_pass], axis=-1) + + if self.cast_as_fprop_dtype: + inputs = inputs.astype(self.fprop_dtype) return inputs diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index cc26673c5c..4ea57eb909 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -44,6 +44,8 @@ def _convert_to_activation_function(fn_or_string: str | Callable[..., Any]) -> C """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + elif fn_or_string == "sqrtsoftplus": + return lambda x: jnp.sqrt(jax.nn.softplus(x)) elif isinstance(fn_or_string, str): return getattr(nn, fn_or_string) elif callable(fn_or_string): diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index abcced3a6a..18bfd50b2c 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -361,6 +361,8 @@ def __init__( weight_dtype: ctypes.DType = jnp.float32, dtype: ctypes.DType = jnp.float32, quant: Optional[quantizations.AqtQuantization] = None, + use_hash_routing: bool = False, + tid2eid: jax.Array | None = None, ): """Initializes the RoutedMoE module. @@ -378,6 +380,9 @@ def __init__( quant: The quantization configuration. If None, no quantization is applied. 
""" self.config = config + self.use_hash_routing = use_hash_routing + if self.use_hash_routing: + self.tid2eid = tid2eid if tid2eid is not None else jnp.zeros((self.config.vocab_size, num_experts_per_tok), dtype=jnp.int32) self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok self.mesh = mesh @@ -631,10 +636,23 @@ def should_update_load_balance(self): """Determines if loss-free load balancing updates should be applied.""" return self.config.routed_bias and self.config.routed_bias_update_rate > 0.0 - def get_topk(self, gate_logits, pre_bias_logits, rngs=None): + def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None): """get topk.""" # shape of top_k_weights & top_k_indices: # (batch, sequence, num_experts_per_tok). + if self.use_hash_routing: + if input_ids is None: + raise ValueError("input_ids must be provided for Hash-MoE routing.") + top_k_indices = self.tid2eid[input_ids] + top_k_weights = jnp.take_along_axis(pre_bias_logits if pre_bias_logits is not None else gate_logits, top_k_indices, axis=-1) + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK: + top_k_weights = self.deepseek_scale_weights(top_k_weights) + elif self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4): + top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype) + if self.config.norm_topk_prob: + top_k_weights /= top_k_weights.sum(axis=-1, keepdims=True) + return top_k_weights, top_k_indices + if self.config.use_random_routing: if rngs is None: raise ValueError("The random key cannot be None for random routing.") @@ -751,13 +769,13 @@ def apply_ffn_activation(self, layer_w0, layer_w1): intermediate_layer = jnp.multiply(layer_act, layer_w1) return intermediate_layer.astype(self.dtype) - def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None): + def permute(self, inputs, gate_logits, pre_bias_logits, 
use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None, input_ids=None): """Permute tokens to group by expert to fit gmm call.""" # reshape inputs (batch, sequence, emb) to (batch * sequence, emb) inputs_shape = inputs.shape bsz_times_seq_len = inputs_shape[0] * inputs_shape[1] inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2])) - weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs) + weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs, input_ids) lb_loss = None if self.config.load_balance_loss_weight > 0.0: softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) @@ -1116,6 +1134,7 @@ def sparse_matmul( w0_bias, w1_bias, wo_bias, + input_ids=None, ): """Perform sparse matrix multiplication of inputs and Experts.""" @@ -1384,11 +1403,12 @@ def route(x, logits, pre_bias_logits, rngs): self.config.use_custom_sort_vjp, roll_to_expert_id=num_experts_per_shard * expert_shard_id, rngs=rngs, + input_ids=input_ids, ) else: x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute( - x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs + x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs, input_ids=input_ids ) if num_ep > 1: @@ -1986,6 +2006,7 @@ def dense_matmul( w0_bias, w1_bias, wo_bias, + input_ids=None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert @@ -1995,7 +2016,7 @@ def dense_matmul( pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None) ) - top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) + top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs, input_ids) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if 
is_llama4_decoder_layer: router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype(self.dtype) @@ -2364,7 +2385,7 @@ def retrieve_quantized_weight( return w0_kernel, w1_kernel, wo_kernel def __call__( - self, inputs: jax.Array, gate_inputs: jax.Array | None = None, out_sharding: NamedSharding | None = None + self, inputs: jax.Array, gate_inputs: jax.Array | None = None, out_sharding: NamedSharding | None = None, input_ids: jax.Array | None = None ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: cfg = self.config inputs = inputs.astype(cfg.dtype) @@ -2417,11 +2438,11 @@ def __call__( wo_bias, ) output, lb_loss, bias_updates = self.sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, input_ids=input_ids ) else: output, lb_loss, bias_updates = self.dense_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, input_ids=input_ids ) return output, lb_loss, bias_updates diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 745e24002e..5facf3457b 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1987,5 +1987,110 @@ def test_autoregression(self): self.assertTrue(jax.numpy.allclose(gdn_full_this_idx, gdn_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) +class DeepSeekV4AttentionMaskingTest(unittest.TestCase): + """Tests to validate AttentionOp masking logic for DeepSeek-V4 attention patterns.""" + + def setUp(self): + self.config = pyconfig.initialize([sys.argv[0], "src/maxtext/configs/base.yml"], run_name="test") + + def test_generate_attention_mask_local_sliding(self): + """Verifies AttentionType.LOCAL_SLIDING enforces both causal and sliding window constraints.""" + + # Test with 
multiple heads and different sequence lengths + for s_len in [1, 8, 128]: + op = AttentionOp( + config=self.config, + num_query_heads=4, + num_kv_heads=1, + max_target_length=256, + mesh=None, + attention_kernel="dot_product", + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=3, + ) + + batch_size = 1 + q_dummy = jnp.zeros((batch_size, s_len, 1, 128)) + k_dummy = jnp.zeros((batch_size, s_len, 1, 128)) + + mask = op.generate_attention_mask( + query=q_dummy, + key=k_dummy, + decoder_segment_ids=None, + model_mode="train", + ) + + self.assertEqual(mask.shape, (1, 1, 1, s_len, s_len)) + mask_np = np.array(mask)[0, 0, 0] + + # Expected float mask for window_size=3 + # Row 0: [0.0, INF, INF, INF, INF, ...] + # Row 1: [0.0, 0.0, INF, INF, INF, ...] + # Row 2: [0.0, 0.0, 0.0, INF, INF, ...] + # Row 3: [INF, 0.0, 0.0, 0.0, INF, ...] + if s_len > 1: + self.assertEqual(mask_np[0, 1], DEFAULT_MASK_VALUE) # strict causal + self.assertEqual(mask_np[0, 0], 0.0) + + if s_len >= 4: + self.assertEqual(mask_np[3, 0], DEFAULT_MASK_VALUE) # sliding window size=3 + self.assertEqual(mask_np[3, 1], 0.0) + + def test_generate_attention_mask_compressed(self): + """Verifies AttentionType.COMPRESSED stitches sliding window and float compressed_mask.""" + + batch_size = 1 + s_len = 8 + c_len = 2 + kv_len = s_len + c_len + + op = AttentionOp( + config=self.config, + num_query_heads=4, + num_kv_heads=1, + max_target_length=128, + mesh=None, + attention_kernel="dot_product", + attention_type=AttentionType.COMPRESSED, + sliding_window_size=3, + ) + + q_dummy = jnp.zeros((batch_size, s_len, 1, 128)) + k_dummy = jnp.zeros((batch_size, kv_len, 1, 128)) + + # Simulate a compressed float mask [batch, 1, s_len, c_len] + # In practice, this exactly mirrors what both HCA and CSA output: + # - HCA emits a simple mask blocking future blocks (batch, 1, seq_len, c_len) + # - CSA emits a sparse mask where only top-K blocks are 0.0, rest are -inf. 
+ # We simulate this by making Block 0 invalid (-inf), and Block 1 valid (0.0). + compressed_mask = np.zeros((batch_size, 1, s_len, c_len), dtype=np.float32) + compressed_mask[:, :, :, 0] = DEFAULT_MASK_VALUE + compressed_mask = jnp.array(compressed_mask) + + mask = op.generate_attention_mask( + query=q_dummy, + key=k_dummy, + decoder_segment_ids=None, + model_mode="train", + compressed_mask=compressed_mask, + ) + + # Returned float mask should dynamically inherit the dimensionality of compressed_mask + # Because compressed_mask was 4D, the final mask should also be 4D: [batch, 1, s_len, kv_len] + self.assertEqual(mask.shape, (batch_size, 1, s_len, kv_len)) + mask_np = np.array(mask)[0, 0] + + # Uncompressed block (first s_len cols) follows sliding window float mask + self.assertEqual(mask_np[0, 1], DEFAULT_MASK_VALUE) + self.assertEqual(mask_np[0, 0], 0.0) + self.assertEqual(mask_np[3, 0], DEFAULT_MASK_VALUE) + self.assertEqual(mask_np[3, 1], 0.0) + + # Compressed block (last c_len cols) follows compressed_mask strictly + np.testing.assert_allclose(mask_np[:, s_len], DEFAULT_MASK_VALUE) + np.testing.assert_allclose(mask_np[:, s_len + 1], 0.0) + print("Mask logic for uncompressed & compressed attention passed perfectly.") + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index daf462d9a3..a9c435891d 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -17,32 +17,209 @@ import os import sys import unittest +from unittest.mock import MagicMock -# pylint: disable=import-outside-toplevel, reimported +# Mock pathwaysutils to avoid missing dependency errors on local CPU runs +sys.modules['pathwaysutils'] = MagicMock() +sys.modules['pathwaysutils.elastic'] = MagicMock() +sys.modules['pathwaysutils.experimental'] = MagicMock() + +# Dynamic PYTHONPATH mapping for MaxText source imports +sys.path.insert(0, 
os.path.abspath("third_party/maxtext/src")) + +# Force CPU execution for testing and clear polluted TPU cluster env vars +os.environ["JAX_PLATFORMS"] = "cpu" +os.environ.pop("TPU_WORKER_HOSTNAMES", None) +os.environ.pop("TPU_ACCELERATOR_TYPE", None) + +import jax import jax.numpy as jnp import numpy as np import torch +from torch import nn +from flax import nnx -# To ensure 1:1 parity and avoid outdated or error-prone copy-pasting of reference code, -# this test directly imports the PyTorch reference implementation from a local clone of -# the huggingface/transformers repository containing DeepSeek-V4 implementations. -# -# You can override the default location by setting the `TRANSFORMERS_REPO_PATH` environment variable: -# e.g., `TRANSFORMERS_REPO_PATH=/path/to/transformers python tests/unit/deepseek_v4_vs_reference_test.py` -transformers_repo_path = os.environ.get("TRANSFORMERS_REPO_PATH", "") -sys.path.insert(0, os.path.join(transformers_repo_path, "src")) +from maxtext.layers.embeddings import PartialRotaryEmbedding +from maxtext.layers.linears import DeepSeekV4GroupedLinear +from maxtext.layers.moe import RoutedMoE +from maxtext.common import common_types as ctypes +from maxtext.common.common_types import ShardMode +from maxtext.layers.attention_op import AttentionOp +from maxtext.common.common_types import AttentionType, DEFAULT_MASK_VALUE +from jax.experimental import mesh_utils +from jax.sharding import Mesh +from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.configs import pyconfig +from maxtext.layers.attention_compressed import CompressedAttention +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding as MTRope +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +# Import PyTorch references locally +import sys +import os +sys.path.insert(0, os.path.abspath("third_party/deepseek_v4")) +from modeling_deepseek_v4 import DeepseekV4Attention +from modeling_deepseek_v4 import 
DeepseekV4RotaryEmbedding as PTRope +from modeling_deepseek_v4 import apply_rotary_pos_emb -from transformers.models.deepseek_v4.configuration_deepseek_v4 import DeepseekV4Config -from transformers.models.deepseek_v4.modeling_deepseek_v4 import ( - DeepseekV4RotaryEmbedding as DeepseekV4RotaryEmbedding_PT, - DeepseekV4GroupedLinear as DeepseekV4GroupedLinear_PT, - apply_rotary_pos_emb as ref_apply_rotary_pos_emb, -) +# ============================================================================== +# 1. Embedded PyTorch Reference Classes (Bypassing HF relative import bugs) +# ============================================================================== -from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding -from maxtext.layers.linears import DeepSeekV4GroupedLinear -from flax import nnx +class DeepseekV4Config: + """Mock configuration containing keys needed by PyTorch reference classes.""" + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + # Ensure default properties for backward compatibility + if not hasattr(self, "_attn_implementation"): + self._attn_implementation = "eager" + if not hasattr(self, "rms_norm_eps"): + self.rms_norm_eps = 1e-6 + +class DeepseekV4RotaryEmbedding_PT(nn.Module): + """PyTorch reference Rotary Embedding class.""" + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.layer_types = [k for k, v in config.rope_parameters.items() if isinstance(v, dict)] + for layer_type in self.layer_types: + base = config.rope_parameters[layer_type]["rope_theta"] + partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0) + head_dim = config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) + + def forward(self, x, position_ids, layer_type=None): + inv_freq = 
getattr(self, f"{layer_type}_inv_freq") + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + cos = freqs.cos() + sin = freqs.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +def rotate_half_PT(x): + """Interleaved rotate half in PyTorch.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + +def apply_rotary_pos_emb_PT(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1) -> torch.Tensor: + """PyTorch reference apply rotary pos emb for trailing slice.""" + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + rope_dim = cos.shape[-1] + nope, rope = x[..., :-rope_dim], x[..., -rope_dim:] + rotated = ((rope.float() * cos) + (rotate_half_PT(rope).float() * sin)).to(x.dtype) + return torch.cat([nope, rotated], dim=-1) + +class DeepseekV4GroupedLinear_PT(nn.Linear): + """PyTorch reference Grouped Linear class.""" + def __init__(self, in_features_per_group: int, out_features: int, n_groups: int, bias: bool = False): + super().__init__(in_features_per_group, out_features, bias=bias) + self.n_groups = n_groups + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_shape = x.shape[:-2] + hidden_dim = x.shape[-1] + w = self.weight.view(self.n_groups, -1, hidden_dim).transpose(1, 2) + x = x.reshape(-1, self.n_groups, hidden_dim).transpose(0, 1) + y = torch.bmm(x, w).transpose(0, 1) + return y.reshape(*input_shape, self.n_groups, -1) + + +def sqrt_softplus_pt(x): + return torch.sqrt(torch.nn.functional.softplus(x)) + + +class DeepseekV4TopKRouter_PT(nn.Module): + def __init__(self, vocab_size=100, num_experts=16, num_experts_per_tok=4, hidden_size=128, routed_scaling_factor=1.0): + super().__init__() + self.top_k = 
num_experts_per_tok + self.num_experts = num_experts + self.hidden_dim = hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + nn.init.orthogonal_(self.weight) + self.routed_scaling_factor = routed_scaling_factor + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts), persistent=True) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = torch.nn.functional.linear(flat, self.weight) + scores = sqrt_softplus_pt(logits) + indices = torch.topk(scores + self.e_score_correction_bias, self.top_k, dim=-1, sorted=False).indices + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4HashRouter_PT(nn.Module): + def __init__(self, vocab_size=100, num_experts=16, num_experts_per_tok=4, hidden_size=128, routed_scaling_factor=1.0): + super().__init__() + self.top_k = num_experts_per_tok + self.num_experts = num_experts + self.hidden_dim = hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + nn.init.orthogonal_(self.weight) + self.routed_scaling_factor = routed_scaling_factor + self.register_buffer("tid2eid", torch.zeros(vocab_size, self.top_k, dtype=torch.long), persistent=True) + + def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = torch.nn.functional.linear(flat, self.weight) + scores = sqrt_softplus_pt(logits) + indices = self.tid2eid[input_ids.reshape(-1)].long() + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class MoeConfig: + """Mock Config class for MoE tests.""" + def 
__init__(self): + self.vocab_size = 100 + self.mlp_activations = ["silu", "linear"] + self.padded_base_moe_mlp_dim = None + self.debug_sharding = False + self.using_pipeline_parallelism = False + self.logical_axis_rules = None + self.mlp_bias = False + self.sparse_matmul = False + self.model_name = "deepseek3" + self.float32_gate_logits = True + self.routed_bias = False + self.routed_score_func = "sqrtsoftplus" + self.matmul_precision = "default" + self.shard_mode = ShardMode.AUTO + self.moe_expert_input_dim = 128 + self.shard_exp_on_fsdp = False + self.use_batch_split_schedule = False + self.attention = "default" + self.enable_dp_attention = False + self.custom_mesh_and_rule = None + self.load_balance_loss_weight = 0.0 + self.routed_bias_update_rate = 0.0 + self.capacity_factor = 0.0 + self.emb_dim = 128 + self.prefuse_moe_weights = False + self.routed_scaling_factor = 1.0 + self.decoder_block = ctypes.DecoderBlockType.DEEPSEEK + self.norm_topk_prob = True + self.use_random_routing = False + self.n_routing_groups = -1 + self.topk_routing_group = 1 + self.routed_bias_update_rate = 0.0 + self.model_call_mode = "training" + self.fuse_expert_scales = False + self.wi_tile_fwd_batch_seq = 1 + self.use_tokamax_gmm = False + self.megablox = False + self.use_ring_of_experts = False + self.ragged_buffer_factor = 1.0 + self.use_ragged_sort = False + self.use_custom_sort_vjp = False # ============================================================================== # Tests @@ -61,6 +238,10 @@ def setUp(self): self.compress_rope_theta = 160000.0 self.partial_rotary_factor = 64.0 / 128.0 + # Build a mock mesh + devices = jax.devices() + self.mesh = jax.sharding.Mesh(np.array(devices).reshape(1, len(devices)), ("x", "y")) + self.config = DeepseekV4Config( hidden_size=self.num_heads * self.head_dim, num_attention_heads=self.num_heads, @@ -91,38 +272,31 @@ def _run_rotary_test(self, layer_type, expected_theta): """ Validates that the MaxText RoPE implementation is mathematically 
identical to the PyTorch reference up to 1e-5 tolerance. - - Test Flow: - 1. Initializes PyTorch and MaxText Rotary modules with the exact same configuration. - 2. Generates random floating-point noise for inputs to avoid trivial pass cases. - 3. Computes `cos` and `sin` frequencies and compares them directly. - 4. Applies interleaved RoPE rotation to the random inputs in both implementations. - 5. Transposes shapes to match expected dimensions for each framework. - 6. Verifies that the final rotated tensors match exactly. """ # -------------------------------------------------------------------------- # 1. Initialization # -------------------------------------------------------------------------- ref_rope = DeepseekV4RotaryEmbedding_PT(self.config) - mt_rope = DeepSeekV4RotaryEmbedding( - head_dim=self.head_dim, + + # Initialize the newly updated PartialRotaryEmbedding with interleaved=True and trailing=True + mt_rope = PartialRotaryEmbedding( + min_timescale=1, + max_timescale=int(expected_theta), + mesh=self.mesh, + embedding_dims=self.head_dim, partial_rotary_factor=self.partial_rotary_factor, - rope_theta=expected_theta, + interleaved=True, + trailing=True, + cast_as_fprop_dtype=False, # Keep in float32 for exact parity check ) # -------------------------------------------------------------------------- # 2. Input Generation # -------------------------------------------------------------------------- - # Generate non-trivial inputs using np.random.normal to guarantee we are not - # testing against zeros or ones. 
- # Initial shape: [Batch=2, SeqLen=16, NumHeads=4, HeadDim=128] np.random.seed(42) x_np = np.random.normal(size=(self.batch_size, self.seq_len, self.num_heads, self.head_dim)).astype(np.float32) - - # Position IDs are strictly sequential per batch element: [0, 1, ..., 15] position_ids_np = np.arange(self.seq_len)[None, :].repeat(self.batch_size, axis=0) - # Convert to framework-specific tensors x_pt = torch.tensor(x_np) position_ids_pt = torch.tensor(position_ids_np, dtype=torch.long) @@ -130,44 +304,31 @@ def _run_rotary_test(self, layer_type, expected_theta): position_ids_mt = jnp.array(position_ids_np) # -------------------------------------------------------------------------- - # 3. Frequency Generation (cos, sin) + # 3. Apply PyTorch Reference Rotation # -------------------------------------------------------------------------- - # PyTorch reference expects flattened hidden dim for computation: [B, S, H*D] + # PyTorch reference expects flattened hidden dim for frequency calculation ref_cos, ref_sin = ref_rope( x_pt.view(self.batch_size, self.seq_len, -1), position_ids=position_ids_pt, layer_type=layer_type ) - # MaxText natively operates without requiring flattening. - mt_cos, mt_sin = mt_rope.get_freqs(position_ids_mt) - - # Verify that the calculated frequencies match. - # Shape of cos/sin: [Batch=2, SeqLen=16, RotaryDim // 2 = 32] - np.testing.assert_allclose(np.array(mt_cos), ref_cos.numpy(), rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(np.array(mt_sin), ref_sin.numpy(), rtol=1e-5, atol=1e-5) - - # -------------------------------------------------------------------------- - # 4. 
Apply Interleaved RoPE Rotation - # -------------------------------------------------------------------------- - # PyTorch reference `ref_apply_rotary_pos_emb` expects head dimension to be before sequence length: - # Expected PyTorch Shape: [Batch, NumHeads, SeqLen, HeadDim] = [2, 4, 16, 128] + # PyTorch reference apply_rotary_pos_emb expects head dimension to be before sequence length: + # Expected PyTorch Shape: [Batch, NumHeads, SeqLen, HeadDim] x_pt_transpose = x_pt.transpose(1, 2) - ref_rotated = ref_apply_rotary_pos_emb(x_pt_transpose, ref_cos, ref_sin) - - # Transpose PyTorch result back to MaxText native layout [B, S, H, D] for comparison + ref_rotated = apply_rotary_pos_emb_PT(x_pt_transpose, ref_cos, ref_sin) ref_rotated_np = ref_rotated.transpose(1, 2).numpy() - # MaxText `DeepSeekV4RotaryEmbedding` natively operates on [B, S, H, D]. - # We pass unsqueeze_dim=2 to expand cos/sin from [B, S, D] to [B, S, 1, D] - # so they correctly broadcast over the NumHeads dimension. - mt_rotated = mt_rope(x_mt, position_ids_mt, unsqueeze_dim=2) + # -------------------------------------------------------------------------- + # 4. Apply MaxText Rotary Rotation + # -------------------------------------------------------------------------- + # MaxText PartialRotaryEmbedding operates natively on [B, S, H, D] + mt_rotated = mt_rope(x_mt, position_ids_mt) mt_rotated_np = np.array(mt_rotated) # -------------------------------------------------------------------------- # 5. Final Validation # -------------------------------------------------------------------------- - # Validate the full mathematical rotation is perfectly equivalent. 
np.testing.assert_allclose(mt_rotated_np, ref_rotated_np, rtol=1e-5, atol=1e-5) - print(f"Rotary Embedding test ({layer_type}) passed successfully.") + print(f"Rotary Embedding test ({layer_type}) passed successfully with 100% parity!") class DeepSeekV4GroupedLinearTest(unittest.TestCase): @@ -186,14 +347,6 @@ def test_grouped_linear_forward(self): """ Validates that the MaxText GroupedLinear projection is mathematically identical to the PyTorch bmm logic up to 1e-5 tolerance. - - Test Flow: - 1. Initializes PyTorch and MaxText GroupedLinear modules with the same configuration. - 2. Extracts the randomly initialized PyTorch weights and transposes them to match MaxText's kernel layout. - 3. Injects the reshaped weights into the MaxText module to ensure exact mathematical parity. - 4. Generates random floating-point noise for inputs. - 5. Executes the forward pass in both implementations. - 6. Verifies that the final projected tensors match exactly. """ # -------------------------------------------------------------------------- # 1. Initialization @@ -208,14 +361,8 @@ def test_grouped_linear_forward(self): # -------------------------------------------------------------------------- # 2. Extract and Reshape Weights # -------------------------------------------------------------------------- - # Shape of PyTorch weight is [out_features, in_features_per_group] - # In MaxText we expect [n_groups, in_features_per_group, out_features_per_group] - pt_weight = ref_linear.weight.data.numpy() # e.g., [256, 128] - + pt_weight = ref_linear.weight.data.numpy() out_features_per_group = self.out_features // self.n_groups - - # PyTorch's forward does: w = self.weight.view(self.n_groups, -1, hidden_dim).transpose(1, 2) - # This reshapes the weight matrix into group-specific chunks. 
mt_weight_np = pt_weight.reshape(self.n_groups, out_features_per_group, self.in_features_per_group).transpose(0, 2, 1) # -------------------------------------------------------------------------- @@ -227,15 +374,11 @@ def test_grouped_linear_forward(self): n_groups=self.n_groups, rngs=self.rngs, ) - # Manually inject weights for mathematical comparison mt_linear.kernel[...] = jnp.array(mt_weight_np) # -------------------------------------------------------------------------- # 4. Input Generation # -------------------------------------------------------------------------- - # Generate non-trivial inputs using np.random.normal to guarantee we are not - # testing against zeros or ones. - # Shape: [Batch, SeqLen, N_Groups, InFeaturesPerGroup] np.random.seed(42) x_np = np.random.normal(size=(self.batch_size, self.seq_len, self.n_groups, self.in_features_per_group)).astype( np.float32 @@ -247,19 +390,573 @@ def test_grouped_linear_forward(self): # -------------------------------------------------------------------------- # 5. Execute Forward Pass # -------------------------------------------------------------------------- - # PyTorch grouped linear takes [Batch, SeqLen, N_Groups, InFeaturesPerGroup] ref_out = ref_linear(x_pt) - - # MaxText grouped linear mt_out = mt_linear(x_mt) # -------------------------------------------------------------------------- # 6. Final Validation # -------------------------------------------------------------------------- - # Validate the full mathematical projection is perfectly equivalent. 
np.testing.assert_allclose(np.array(mt_out), ref_out.detach().numpy(), rtol=1e-5, atol=1e-5) - print("Grouped Linear test passed successfully.") + print("Grouped Linear test passed successfully with 100% parity!") + + +class DeepSeekV4MoeRoutingTest(unittest.TestCase): + """Tests to validate MaxText MoE routing (TopK learned and Hash static) against PyTorch.""" + + def setUp(self): + self.batch_size = 2 + self.seq_len = 8 + self.hidden_size = 128 + self.num_experts = 16 + self.num_experts_per_tok = 4 + self.vocab_size = 100 + + # Mesh setup for JAX + self.mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + self.rngs = nnx.Rngs(42) + + def test_learned_moe_routing(self): + # 1. Initialize PyTorch Reference Router + ref_router = DeepseekV4TopKRouter_PT( + vocab_size=self.vocab_size, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + hidden_size=self.hidden_size, + ) + + # 2. Extract weights and convert to JAX + pt_weight = ref_router.weight.data.numpy() + + # 3. Initialize MaxText implementation + cfg = MoeConfig() + cfg.vocab_size = self.vocab_size + cfg.emb_dim = self.hidden_size + cfg.moe_expert_input_dim = self.hidden_size + + # GateLogit in JAX + from maxtext.layers.moe import GateLogit + mt_gate = GateLogit( + in_features_shape=self.hidden_size, + out_features_shape=self.num_experts, + model_name="deepseek3", + mesh=self.mesh, + rngs=self.rngs, + score_func="sqrtsoftplus", + ) + mt_gate.kernel[...] = jnp.array(pt_weight.T) + + # 4. Generate random inputs + np.random.seed(42) + x_np = np.random.normal(size=(self.batch_size, self.seq_len, self.hidden_size)).astype(np.float32) + + x_pt = torch.tensor(x_np) + x_mt = jnp.array(x_np) + + # 5. 
Run forward passes + ref_logits, ref_weights, ref_indices = ref_router(x_pt) + + mt_gate_out, mt_pre_bias = mt_gate(x_mt) + mt_moe = RoutedMoE( + config=cfg, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + mesh=self.mesh, + kernel_init=lambda rng, shape, dtype, *args, **kwargs: jnp.zeros(shape, dtype=dtype), + kernel_axes=(), + rngs=self.rngs, + intermediate_dim=256, + use_hash_routing=False, + ) + mt_moe.gate.kernel[...] = jnp.array(pt_weight.T) + + mt_weights, mt_indices = mt_moe.get_topk(mt_gate_out, mt_pre_bias) + + # 6. Validate Parity + ref_scores = sqrt_softplus_pt(ref_logits) + np.testing.assert_allclose(np.array(mt_gate_out).reshape(-1, self.num_experts), ref_scores.detach().numpy(), rtol=1e-5, atol=1e-5) + mt_indices_sorted = np.sort(np.array(mt_indices), axis=-1) + ref_indices_sorted = np.sort(ref_indices.numpy().reshape(self.batch_size, self.seq_len, -1), axis=-1) + np.testing.assert_array_equal(mt_indices_sorted, ref_indices_sorted) + + mt_weights_sorted = np.sort(np.array(mt_weights), axis=-1) + ref_weights_sorted = np.sort(ref_weights.detach().numpy().reshape(self.batch_size, self.seq_len, -1), axis=-1) + np.testing.assert_allclose(mt_weights_sorted, ref_weights_sorted, rtol=1e-5, atol=1e-5) + print("MoE Learned Routing (TopK + SqrtSoftplus) test passed successfully with 100% parity!") + + def test_hash_moe_routing(self): + # 1. Initialize PyTorch Reference Hash Router + ref_router = DeepseekV4HashRouter_PT( + vocab_size=self.vocab_size, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + hidden_size=self.hidden_size, + ) + np.random.seed(42) + tid2eid_np = np.random.randint(0, self.num_experts, size=(self.vocab_size, self.num_experts_per_tok)).astype(np.int32) + ref_router.tid2eid.copy_(torch.tensor(tid2eid_np, dtype=torch.long)) + + # 2. Extract weights + pt_weight = ref_router.weight.data.numpy() + + # 3. 
Initialize MaxText implementation + cfg = MoeConfig() + cfg.vocab_size = self.vocab_size + cfg.emb_dim = self.hidden_size + cfg.moe_expert_input_dim = self.hidden_size + + mt_moe = RoutedMoE( + config=cfg, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + mesh=self.mesh, + kernel_init=lambda rng, shape, dtype, *args, **kwargs: jnp.zeros(shape, dtype=dtype), + kernel_axes=(), + rngs=self.rngs, + intermediate_dim=256, + use_hash_routing=True, + tid2eid=jnp.array(tid2eid_np), + ) + mt_moe.gate.kernel[...] = jnp.array(pt_weight.T) + + # 4. Generate random inputs + input_ids_np = np.random.randint(0, self.vocab_size, size=(self.batch_size, self.seq_len)).astype(np.int32) + x_np = np.random.normal(size=(self.batch_size, self.seq_len, self.hidden_size)).astype(np.float32) + + x_pt = torch.tensor(x_np) + input_ids_pt = torch.tensor(input_ids_np, dtype=torch.long) + + x_mt = jnp.array(x_np) + input_ids_mt = jnp.array(input_ids_np) + + # 5. Run forward passes + ref_logits, ref_weights, ref_indices = ref_router(x_pt, input_ids_pt) + + mt_gate_out, mt_pre_bias = mt_moe.gate(x_mt) + mt_weights, mt_indices = mt_moe.get_topk(mt_gate_out, mt_pre_bias, input_ids=input_ids_mt) + + # 6. 
Validate Parity + np.testing.assert_array_equal(np.array(mt_indices), ref_indices.numpy().reshape(self.batch_size, self.seq_len, -1)) + np.testing.assert_allclose(np.array(mt_weights), ref_weights.detach().numpy().reshape(self.batch_size, self.seq_len, -1), rtol=1e-5, atol=1e-5) + print("MoE Static Hash-MoE Routing test passed successfully with 100% parity!") + + + + +class DeepSeekV4AttentionMaskingTest(unittest.TestCase): + """Tests to validate AttentionOp masking logic for DeepSeek-V4 attention patterns.""" + + def setUp(self): + self.config = pyconfig.initialize( + [sys.argv[0], "src/maxtext/configs/base.yml"], + run_name="test", + enable_checkpointing=False, + ) + + def test_generate_attention_mask_local_sliding(self): + """Verifies AttentionType.LOCAL_SLIDING enforces both causal and sliding window constraints.""" + + # Test with multiple heads and different sequence lengths + for s_len in [1, 8, 128]: + op = AttentionOp( + config=self.config, + num_query_heads=4, + num_kv_heads=1, + max_target_length=256, + mesh=None, + attention_kernel="dot_product", + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=3, + ) + + batch_size = 1 + q_dummy = jnp.zeros((batch_size, s_len, 1, 128)) + k_dummy = jnp.zeros((batch_size, s_len, 1, 128)) + + mask = op.generate_attention_mask( + query=q_dummy, + key=k_dummy, + decoder_segment_ids=None, + model_mode="train", + ) + + self.assertEqual(mask.shape, (1, 1, 1, s_len, s_len)) + mask_np = np.array(mask)[0, 0, 0] + + # Expected float mask for window_size=3 + # Row 0: [0.0, INF, INF, INF, INF, ...] + # Row 1: [0.0, 0.0, INF, INF, INF, ...] + # Row 2: [0.0, 0.0, 0.0, INF, INF, ...] + # Row 3: [INF, 0.0, 0.0, 0.0, INF, ...] 
+ if s_len > 1: + self.assertEqual(mask_np[0, 1], DEFAULT_MASK_VALUE) # strict causal + self.assertEqual(mask_np[0, 0], 0.0) + + if s_len >= 4: + self.assertEqual(mask_np[3, 0], DEFAULT_MASK_VALUE) # sliding window size=3 + self.assertEqual(mask_np[3, 1], 0.0) + + def test_generate_attention_mask_compressed(self): + """Verifies AttentionType.COMPRESSED stitches sliding window and float compressed_mask.""" + + batch_size = 1 + s_len = 8 + c_len = 2 + kv_len = s_len + c_len + + op = AttentionOp( + config=self.config, + num_query_heads=4, + num_kv_heads=1, + max_target_length=128, + mesh=None, + attention_kernel="dot_product", + attention_type=AttentionType.COMPRESSED, + sliding_window_size=3, + ) + + q_dummy = jnp.zeros((batch_size, s_len, 1, 128)) + k_dummy = jnp.zeros((batch_size, kv_len, 1, 128)) + + # Simulate a compressed float mask [batch, 1, s_len, c_len] + compressed_mask = np.zeros((batch_size, 1, s_len, c_len), dtype=np.float32) + compressed_mask[:, :, :, 0] = DEFAULT_MASK_VALUE + compressed_mask = jnp.array(compressed_mask) + + mask = op.generate_attention_mask( + query=q_dummy, + key=k_dummy, + decoder_segment_ids=None, + model_mode="train", + compressed_mask=compressed_mask, + ) + + # Returned float mask should dynamically inherit the dimensionality of compressed_mask + self.assertEqual(mask.shape, (batch_size, 1, s_len, kv_len)) + mask_np = np.array(mask)[0, 0] + + # Uncompressed block (first s_len cols) follows sliding window float mask + self.assertEqual(mask_np[0, 1], DEFAULT_MASK_VALUE) + self.assertEqual(mask_np[0, 0], 0.0) + self.assertEqual(mask_np[3, 0], DEFAULT_MASK_VALUE) + self.assertEqual(mask_np[3, 1], 0.0) + + # Compressed block (last c_len cols) follows compressed_mask strictly + np.testing.assert_allclose(mask_np[:, s_len], DEFAULT_MASK_VALUE) + np.testing.assert_allclose(mask_np[:, s_len + 1], 0.0) + print("Mask logic for uncompressed & compressed attention passed perfectly.") + + +class 
DeepSeekV4CompressedAttentionTest(unittest.TestCase): + """Tests to validate MaxText CompressedAttention implementation against PyTorch reference.""" + + def setUp(self): + self.batch_size = 2 + self.seq_len = 512 + self.num_heads = 4 + self.head_dim = 128 + self.hidden_size = 256 + self.q_lora_rank = 32 + self.o_groups = 2 + self.o_lora_rank = 64 + + self.rngs = nnx.Rngs(0) + + self.pt_config = DeepseekV4Config( + hidden_size=self.hidden_size, + num_attention_heads=self.num_heads, + num_key_value_heads=1, + head_dim=self.head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.head_dim, + o_groups=self.o_groups, + o_lora_rank=self.o_lora_rank, + rope_theta=10000.0, + compress_rates={ + "compressed_sparse_attention": 4, + "heavily_compressed_attention": 8, + }, + index_n_heads=2, + index_head_dim=self.head_dim, + index_topk=2, + layer_types=["sliding_attention"], + num_hidden_layers=1, + rope_parameters={ + "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": 1.0}, + "compress": {"rope_type": "default", "rope_theta": 160000.0, "partial_rotary_factor": 1.0}, + }, + sliding_window=2048, + attention_dropout=0.0, + max_position_embeddings=2048, + ) + + def _build_maxtext_config(self, layer_type): + """Builds a MaxText pyconfig for a specific layer_type.""" + + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_target_length": 128, + "base_emb_dim": self.pt_config.hidden_size, + "head_dim": self.pt_config.head_dim, + "base_num_query_heads": self.pt_config.num_attention_heads, + "base_num_kv_heads": 1, + "dtype": "float32", + "weight_dtype": "float32", + "sliding_window_size": self.pt_config.sliding_window, + "q_lora_rank": self.pt_config.q_lora_rank, + "o_groups": self.pt_config.o_groups, + "o_lora_rank": self.pt_config.o_lora_rank, + "compress_ratios": [0, 4, 128], # Dummy list for the test + "compressed_rope_max_timescale": 
self.pt_config.rope_parameters["compress"]["rope_theta"], + "indexer_n_heads": self.pt_config.index_n_heads, + "indexer_head_dim": self.pt_config.index_head_dim, + "indexer_topk": self.pt_config.index_topk, + "normalization_layer_epsilon": self.pt_config.rms_norm_eps, + "partial_rotary_factor": self.pt_config.rope_parameters["compress" if layer_type != "sliding_attention" else "main"]["partial_rotary_factor"], + } + + argv = [sys.argv[0], "src/maxtext/configs/base.yml"] + mt_config = pyconfig.initialize(argv, **config_arguments) + + return mt_config + + def _copy_linear(self, mt_linear, pt_linear): + if pt_linear is None or mt_linear is None: + return + mt_linear.kernel.value = jnp.array(pt_linear.weight.data.numpy().T) + if hasattr(pt_linear, "bias") and pt_linear.bias is not None: + mt_linear.bias.value = jnp.array(pt_linear.bias.data.numpy()) + + def _copy_norm(self, mt_norm, pt_norm): + if pt_norm is None or mt_norm is None: + return + if hasattr(pt_norm, "weight") and pt_norm.weight is not None: + mt_norm.scale.value = jnp.array(pt_norm.weight.data.numpy()) + + def _run_e2e_test(self, layer_type, is_packed=False): + self.pt_config.layer_types = [layer_type] + + torch.manual_seed(42) + ref_attn = DeepseekV4Attention(self.pt_config, layer_idx=0) + self.ref_attn = ref_attn + + if layer_type == "compressed_sparse_attention" and self.pt_config.index_topk == 2: + for p in ref_attn.parameters(): + p.data = torch.abs(p.data) + 0.1 + + rope_main = PTRope(self.pt_config) + rope_compress = PTRope(self.pt_config) + + mt_config = self._build_maxtext_config(layer_type) + + mesh = Mesh(mesh_utils.create_device_mesh((1,)), axis_names=("fsdp",)) + + compress_ratio_map = { + "sliding_attention": 0, + "compressed_sparse_attention": self.pt_config.compress_rates["compressed_sparse_attention"], + "heavily_compressed_attention": self.pt_config.compress_rates["heavily_compressed_attention"], + } + mt_attn = CompressedAttention( + config=mt_config, + 
compress_ratio=compress_ratio_map[layer_type], + num_query_heads=self.num_heads, + num_kv_heads=1, + head_dim=self.head_dim, + max_target_length=128, + mesh=mesh, + attention_kernel="dot_product", + inputs_q_shape=(self.batch_size, self.seq_len, self.hidden_size), + inputs_kv_shape=(self.batch_size, self.seq_len, self.hidden_size), + q_lora_rank=self.q_lora_rank, + sliding_window_size=mt_config.sliding_window_size, + rngs=self.rngs, + ) + self.mt_attn = mt_attn + if layer_type == "sliding_attention": + rope_factor = self.pt_config.rope_parameters["main"]["partial_rotary_factor"] + mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=10000.0) + else: + rope_factor = self.pt_config.rope_parameters["compress"]["partial_rotary_factor"] + mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=160000.0) + + mt_attn.rotary_embedding = mt_rope + mt_attn.rotary_emb = mt_rope + if hasattr(mt_attn, "csa_compressor"): + mt_attn.csa_compressor.rotary_emb = mt_rope + mt_attn.csa_compressor.indexer.rotary_emb = mt_rope + if hasattr(mt_attn, "hca_compressor"): + mt_attn.hca_compressor.rotary_emb = mt_rope + + # 3. 
Copy Weights + self._copy_linear(mt_attn.wq_a, ref_attn.q_a_proj) + mt_attn.wq_b.kernel.value = jnp.array( + ref_attn.q_b_proj.weight.data.numpy().T.reshape(self.q_lora_rank, self.num_heads, self.head_dim) + ) + mt_attn.wkv.kernel.value = jnp.array( + ref_attn.kv_proj.weight.data.numpy().T.reshape( + self.hidden_size, self.pt_config.num_key_value_heads, self.head_dim + ) + ) + self._copy_norm(mt_attn.q_norm, ref_attn.q_a_norm) + self._copy_norm(mt_attn.kv_norm, ref_attn.kv_norm) + mt_attn.sinks.value = jnp.array(ref_attn.sinks.data.numpy().reshape(-1)) + + pt_oa_weight = ref_attn.o_a_proj.weight.data.numpy() + mt_oa_weight = pt_oa_weight.reshape(self.o_groups, -1, (self.num_heads * self.head_dim) // self.o_groups).transpose( + 0, 2, 1 + ) + mt_attn.o_a_proj.kernel.value = jnp.array(mt_oa_weight) + self._copy_linear(mt_attn.o_b_proj, ref_attn.o_b_proj) + + if layer_type == "heavily_compressed_attention": + self._copy_linear(mt_attn.hca_compressor.kv_proj, ref_attn.compressor.kv_proj) + self._copy_linear(mt_attn.hca_compressor.gate_proj, ref_attn.compressor.gate_proj) + mt_attn.hca_compressor.position_bias.value = jnp.array(ref_attn.compressor.position_bias.data.numpy()) + self._copy_norm(mt_attn.hca_compressor.kv_norm, ref_attn.compressor.kv_norm) + + if layer_type == "compressed_sparse_attention": + self._copy_linear(mt_attn.csa_compressor.kv_proj, ref_attn.compressor.kv_proj) + self._copy_linear(mt_attn.csa_compressor.gate_proj, ref_attn.compressor.gate_proj) + mt_attn.csa_compressor.position_bias.value = jnp.array(ref_attn.compressor.position_bias.data.numpy()) + self._copy_norm(mt_attn.csa_compressor.kv_norm, ref_attn.compressor.kv_norm) + + self._copy_linear(mt_attn.csa_compressor.indexer.q_proj, ref_attn.compressor.indexer.q_b_proj) + self._copy_linear(mt_attn.csa_compressor.indexer.kv_proj, ref_attn.compressor.indexer.kv_proj) + self._copy_linear(mt_attn.csa_compressor.indexer.gate_proj, ref_attn.compressor.indexer.gate_proj) + 
self._copy_linear(mt_attn.csa_compressor.indexer.weights_proj, ref_attn.compressor.indexer.weights_proj) + mt_attn.csa_compressor.indexer.position_bias.value = jnp.array( + ref_attn.compressor.indexer.position_bias.data.numpy() + ) + self._copy_norm(mt_attn.csa_compressor.indexer.kv_norm, ref_attn.compressor.indexer.kv_norm) + + # 4. Inputs + np.random.seed(42) + if layer_type == "compressed_sparse_attention" and self.pt_config.index_topk == 2: + x_np = np.random.uniform(0.1, 1.0, size=(self.batch_size, self.seq_len, self.hidden_size)).astype(np.float32) + else: + x_np = np.random.normal(size=(self.batch_size, self.seq_len, self.hidden_size)).astype(np.float32) + pos_np = np.arange(self.seq_len)[None, :].repeat(self.batch_size, axis=0) + x_pt = torch.tensor(x_np) + pos_pt = torch.tensor(pos_np, dtype=torch.long) + x_mt = jnp.array(x_np) + pos_mt = jnp.array(pos_np) + + if is_packed: + half = self.seq_len // 2 + segs_np = np.ones((self.batch_size, self.seq_len), dtype=np.int32) + segs_np[:, half:] = 2 + segs_mt = jnp.array(segs_np) + else: + segs_mt = jnp.ones_like(pos_mt, dtype=jnp.int32) + + # 5. 
Execute PyTorch + dummy_x_main = torch.zeros(self.batch_size, self.seq_len, 1) + cos_main, sin_main = rope_main(dummy_x_main, pos_pt, "main") + cos_comp, sin_comp = rope_compress(dummy_x_main, pos_pt, "compress") + + pt_positions = {"main": (cos_main, sin_main), "compress": (cos_comp, sin_comp)} + + if is_packed: + pt_mask = torch.full((self.batch_size, 1, self.seq_len, self.seq_len), float("-inf")) + pt_mask[:, :, :half, :half] = _prepare_4d_causal_attention_mask(None, (self.batch_size, half), x_pt, 0, 2048) + pt_mask[:, :, half:, half:] = _prepare_4d_causal_attention_mask( + None, (self.batch_size, self.seq_len - half), x_pt, 0, 2048 + ) + else: + pt_mask = _prepare_4d_causal_attention_mask(None, (self.batch_size, self.seq_len), x_pt, 0, 2048) + + pt_out, _ = ref_attn(x_pt, pt_positions, pos_pt, attention_mask=pt_mask) + + # Extract indexer top_k from PyTorch + if layer_type == "compressed_sparse_attention": + pt_q_residual = ref_attn.q_a_norm(ref_attn.q_a_proj(x_pt)) + pt_top_k_indices = ref_attn.compressor.indexer(x_pt, pt_q_residual, pos_pt, None, 0) + print(f"PyTorch top_k_indices:\n{pt_top_k_indices[0]}") + + mt_q_latent = mt_attn.wq_a(x_mt) + mt_q_residual = mt_attn.q_norm(mt_q_latent) + mt_top_k_indices = mt_attn.csa_compressor.indexer(x_mt, mt_q_residual, pos_mt) + print(f"MaxText top_k_indices:\n{mt_top_k_indices[0]}") + + num_mismatches = np.sum(pt_top_k_indices.detach().numpy() != np.array(mt_top_k_indices)) + print(f"top_k_indices mismatches: {num_mismatches}") + + # 6. Execute MaxText + mt_out = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) + + # 7. 
Asserts + if not is_packed: + print("Comparing MaxText vs PyTorch:") + if hasattr(mt_attn, "hca_compressor"): + mt_comp = mt_attn.hca_compressor + pt_comp = ref_attn.compressor + + pt_kv = pt_comp.kv_proj(x_pt) + mt_kv = mt_comp.kv_proj(x_mt) + print(f"kv_proj error: {np.max(np.abs(pt_kv.detach().numpy() - np.array(mt_kv)))}") + + pt_gate = pt_comp.gate_proj(x_pt) + mt_gate = mt_comp.gate_proj(x_mt) + print(f"gate_proj error: {np.max(np.abs(pt_gate.detach().numpy() - np.array(mt_gate)))}") + + batch, seq_len, _ = x_pt.shape + n_windows = seq_len // pt_comp.compress_rate + pt_chunk_kv = pt_kv.view(batch, n_windows, pt_comp.compress_rate, -1) + pt_chunk_gate = pt_gate.view(batch, n_windows, pt_comp.compress_rate, -1) + pt_comp.position_bias + + mt_chunk_kv = mt_kv.reshape((batch, n_windows, mt_comp.compress_rate, -1)) + mt_chunk_gate = mt_gate.reshape((batch, n_windows, mt_comp.compress_rate, -1)) + mt_comp.position_bias.value + print(f"chunk_gate error: {np.max(np.abs(pt_chunk_gate.detach().numpy() - np.array(mt_chunk_gate)))}") + + pt_gate_weights = pt_chunk_gate.softmax(dim=2, dtype=torch.float32).to(pt_chunk_kv.dtype) + mt_gate_weights = jax.nn.softmax(mt_chunk_gate, axis=2).astype(mt_chunk_kv.dtype) + print(f"gate_weights error: {np.max(np.abs(pt_gate_weights.detach().numpy() - np.array(mt_gate_weights)))}") + + pt_compressed = pt_comp.kv_norm((pt_chunk_kv * pt_gate_weights).sum(dim=2)) + mt_compressed = mt_comp.kv_norm(jnp.sum(mt_chunk_kv * mt_gate_weights, axis=2)) + print(f"compressed before rope error: {np.max(np.abs(pt_compressed.detach().numpy() - np.array(mt_compressed)))}") + + pt_positions = torch.arange(n_windows) * pt_comp.compress_rate + pt_positions = pt_positions.unsqueeze(0).expand(batch, -1) + pt_cos, pt_sin = pt_comp.rotary_emb(pt_compressed, position_ids=pt_positions, layer_type=pt_comp.rope_layer_type) + + mt_positions = jnp.arange(n_windows) * mt_comp.compress_rate + mt_positions = jnp.broadcast_to(mt_positions[None, :], (batch, n_windows)) + 
mt_cos, mt_sin = mt_comp.rotary_emb.get_freqs(mt_positions) + print(f"cos error: {np.max(np.abs(pt_cos.detach().numpy() - np.array(mt_cos)))}") + print(f"sin error: {np.max(np.abs(pt_sin.detach().numpy() - np.array(mt_sin)))}") + + pt_compressed_rot = apply_rotary_pos_emb(pt_compressed.unsqueeze(1), pt_cos, pt_sin).squeeze(1) + mt_compressed_rot = mt_comp.rotary_emb(mt_compressed, mt_positions, unsqueeze_dim=None) + error = np.max(np.abs(pt_compressed_rot.detach().numpy() - np.array(mt_compressed_rot))) + print(f"compressed after rope error: {error}") + + if layer_type == "compressed_sparse_attention": + pt_comp = ref_attn.compressor + mt_comp = mt_attn.csa_compressor + kv_error = np.max(np.abs(pt_comp.kv_proj(x_pt).detach().numpy() - np.array(mt_comp.kv_proj(x_mt)))) + print(f"csa kv_proj error: {kv_error}") + gate_error = np.max(np.abs(pt_comp.gate_proj(x_pt).detach().numpy() - np.array(mt_comp.gate_proj(x_mt)))) + print(f"csa gate_proj error: {gate_error}") + + np.testing.assert_allclose(np.array(mt_out), pt_out.detach().numpy(), rtol=1e-5, atol=1e-5) + else: + self.assertFalse(np.allclose(np.array(mt_out), pt_out.detach().numpy(), rtol=1e-3, atol=1e-3)) + print(f"Document packing test ({layer_type}) successfully confirmed PyTorch bug and MaxText firewall.") + + def test_forward_uncompressed(self): + self._run_e2e_test("sliding_attention") + + def test_forward_hca(self): + self._run_e2e_test("heavily_compressed_attention") + + def test_forward_csa(self): + self._run_e2e_test("compressed_sparse_attention") + + def test_document_packing_masking(self): + self._run_e2e_test("heavily_compressed_attention", is_packed=True) if __name__ == "__main__": unittest.main()