Skip to content

feat(layers): implement custom shape-aligned attention and MoE primit…#4200

Open
katyaoussar wants to merge 1 commit into
AI-Hypercomputer:mainfrom
katyaoussar:feature/deepseek-v4-custom-attention
Open

feat(layers): implement custom shape-aligned attention and MoE primit…#4200
katyaoussar wants to merge 1 commit into
AI-Hypercomputer:mainfrom
katyaoussar:feature/deepseek-v4-custom-attention

Conversation

@katyaoussar

Copy link
Copy Markdown

This PR implements a custom, JAX/Flax NNX integration of the DeepSeek-V4 attention mechanism and core primitives.

Here is a concise summary of what was done to the four core components:

1- RoPE (Rotary Embeddings): Implemented custom interleaved channel frequency pairing ([-x1, x0, -x3, x2]) and partial dimension rotation for precise token position encoding.
2- Grouped Linear: Created parallel, multi-group projection layers to efficiently mix attention head outputs in a single compilable step.
3- MoE (Mixture of Experts): Built the learned Top-K expert routing mechanism along with the custom SqrtSoftplus load-balancing loss to ensure stable training routing.
4- Attention Block: Engineered a unified, TPU-optimized module combining local sliding window attention, overlapping compressed sparse attention (CSA) with a causal indexer, and heavily compressed history attention (HCA) — using block-bias masking to avoid dynamic gather memory stalls.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant