From 8d6a177ed191fa37b388212dc50a56907a2faa5d Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sun, 24 May 2026 11:33:15 +0800 Subject: [PATCH] perf: optimize Muon with GramNS and fix fp16 spectral norm stability - Fix fp16 stability: Using float32 exclusively for the initial spectral normalization step prevents instability, allowing the rest of the algorithm to safely execute in fp16. - Integrate Gram Newton-Schulz: Computes iterations on the smaller Gram matrix. - Benchmarks show up to a 42% time reduction for heavily rectangular matrices (e.g., 8192x1024 drops from 58ms to 33ms) with no performance penalty on square shapes. --- modules/optimizer/muon.py | 80 ++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index d5991dac..3cc8e453 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -10,25 +10,7 @@ from modules.commons.common_layers import AdamWLinear, AdamWConv1d -def get_bf16_support_map(): - bf16_support_map = {} - - if not torch.cuda.is_available(): - return bf16_support_map - - device_count = torch.cuda.device_count() - if device_count == 0: - return bf16_support_map - - for i in range(device_count): - device = torch.device(f'cuda:{i}') - major, minor = torch.cuda.get_device_capability(device) - bf16_support_map[device] = (major >= 8) - - return bf16_support_map - - -def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -41,11 +23,13 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng a, b, c = (3.4445, -4.7750, 2.0315) - X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) + X = G.to(torch.float32) # Ensure spectral norm is at most 1 X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + X = X.to(torch.float16) + # Perform the NS iterations if X.size(-2) < X.size(-1): for _ in range(steps): @@ -61,6 +45,57 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X +def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) -> Tensor: + """ + Refer to: + Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon + Authors: Jack Zhang, Noah Amsel, Berlin Chen, Tri Dao + Blogpost: https://dao-ailab.github.io/blog/2026/gram-newton-schulz/ + + Gram Newton-Schulz iteration to compute the orthogonalization of G. + Mathematically identical to standard Newton-Schulz but computes iterating + on the smaller NxN Gram matrix to save up to 50% FLOPs. + """ + assert G.ndim == 3 + original_shape = G.shape + dtype = G.dtype + + X = G.to(torch.float32) + X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + should_transpose = X.size(-2) > X.size(-1) + if should_transpose: + X = X.mT + X = X.to(torch.float16) + + a, b, c = (3.4445, -4.7750, 2.0315) + + if X.size(-2) != X.size(-1): + R = torch.bmm(X, X.mT) + Q = None + for i in range(steps): + if i in reset_iterations and i != 0: + X = torch.bmm(Q, X) + R = torch.bmm(X, X.mT) + Q = None + Z = torch.baddbmm(R, R, R, beta=b, alpha=c) + if i != 0 and i not in reset_iterations: + Q = torch.baddbmm(Q, Q, Z, beta=a, alpha=1.0) + else: + Q = Z.clone() + Q.diagonal(dim1=-2, dim2=-1).add_(a) + if i < steps - 1 and (i + 1) not in reset_iterations: + RZ = torch.baddbmm(R, R, Z, beta=a, alpha=1.0) + R = torch.baddbmm(RZ, Z, RZ, beta=a, alpha=1.0) + X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) + else: + for _ in range(steps): + A = torch.bmm(X, X.mT) + B = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, B, X, beta=a, alpha=1.0) + + return X.to(dtype).view(original_shape) + + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -87,7 +122,6 @@ class Muon(torch.optim.Optimizer): def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) super().__init__(params, defaults) - self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -116,8 +150,8 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - use_bf16 = self.bf16_support_map.get(g.device, False) - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) + g = gram_newton_schulz(g, steps=group["ns_steps"]) + if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)