Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class AttentionType(enum.Enum):
LOCAL_SLIDING = "local_sliding"
CHUNK = "chunk"
MLA = "mla"
COMPRESSED = "compressed"
FULL = "full"


Expand Down
6 changes: 6 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,12 @@ qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128

# Compressed Attention parameters
o_lora_rank: 0 # Output LoRA rank for Compressed Attention.
o_groups: 0 # Output groups for Compressed Attention.
compress_ratios: [] # Per-layer compression ratios: 0 disables compression for a layer; any other value must be >= 4 (e.g. 4, 128).
compressed_rope_max_timescale: 160_000 # This value is used for Compressed Sparse/Heavy Attention only if it is positive.

# QK-Clip (Muon Clip) Configuration
use_qk_clip: false # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash)
qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
Expand Down
16 changes: 16 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,17 @@ class MlaAttention(BaseModel):
v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")


class CompressedAttention(BaseModel):
  """Settings group for the Compressed Attention mechanism.

  Every field declares a default, so none of them has to be supplied
  explicitly.
  """

  o_lora_rank: NonNegativeInt = Field(
      default=0,
      description="Output LoRA rank for Compressed Attention.",
  )
  o_groups: NonNegativeInt = Field(
      default=0,
      description="Output groups for Compressed Attention.",
  )
  # A factory (not a literal `[]`) is used so the default list is created fresh.
  compress_ratios: list[int] = Field(
      default_factory=list,
      description="Per-layer compression ratios (0, 4, 128, etc).",
  )
  compressed_rope_max_timescale: int = Field(
      default=160_000,
      description="If positive, used for Compressed Sparse/Heavy Attention.",
  )


class AttentionIndexer(BaseModel):
"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""

Expand Down Expand Up @@ -2269,6 +2280,7 @@ class MaxTextConfig(
# Attention Mechanisms
Attention,
MlaAttention,
CompressedAttention,
MoBa,
AttentionIndexer,
Llama4Attention,
Expand Down Expand Up @@ -3150,6 +3162,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
if self.share_kv_projections and self.attention_type == "mla":
raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")

for val in self.compress_ratios:
if val != 0 and val < 4:
raise ValueError(f"compress_ratio must be 0 (disabled) or >= 4, got {val}")

if self.num_kv_shared_layers > 0:
if self.fused_qkv:
raise ValueError("`num_kv_shared_layers > 0` is not compatible with `fused_qkv`.")
Expand Down
Loading
Loading