Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions modules/optimizer/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,7 @@
from modules.commons.common_layers import AdamWLinear, AdamWConv1d


def get_bf16_support_map():
bf16_support_map = {}

if not torch.cuda.is_available():
return bf16_support_map

device_count = torch.cuda.device_count()
if device_count == 0:
return bf16_support_map

for i in range(device_count):
device = torch.device(f'cuda:{i}')
major, minor = torch.cuda.get_device_capability(device)
bf16_support_map[device] = (major >= 8)

return bf16_support_map


def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
Expand All @@ -41,11 +23,13 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)

X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
X = G.to(torch.float32)

# Ensure spectral norm is at most 1
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)

X = X.to(torch.float16)

# Perform the NS iterations
if X.size(-2) < X.size(-1):
for _ in range(steps):
Expand All @@ -61,6 +45,57 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
return X


def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) -> Tensor:
"""
Refer to:
Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon
Authors: Jack Zhang, Noah Amsel, Berlin Chen, Tri Dao
Blogpost: https://dao-ailab.github.io/blog/2026/gram-newton-schulz/

Gram Newton-Schulz iteration to compute the orthogonalization of G.
Mathematically identical to standard Newton-Schulz but computes iterating
on the smaller NxN Gram matrix to save up to 50% FLOPs.
"""
assert G.ndim == 3
original_shape = G.shape
dtype = G.dtype

X = G.to(torch.float32)
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
should_transpose = X.size(-2) > X.size(-1)
if should_transpose:
X = X.mT
X = X.to(torch.float16)

a, b, c = (3.4445, -4.7750, 2.0315)

if X.size(-2) != X.size(-1):
R = torch.bmm(X, X.mT)
Q = None
for i in range(steps):
if i in reset_iterations and i != 0:
X = torch.bmm(Q, X)
R = torch.bmm(X, X.mT)
Q = None
Z = torch.baddbmm(R, R, R, beta=b, alpha=c)
if i != 0 and i not in reset_iterations:
Q = torch.baddbmm(Q, Q, Z, beta=a, alpha=1.0)
else:
Q = Z.clone()
Q.diagonal(dim1=-2, dim2=-1).add_(a)
if i < steps - 1 and (i + 1) not in reset_iterations:
RZ = torch.baddbmm(R, R, Z, beta=a, alpha=1.0)
R = torch.baddbmm(RZ, Z, RZ, beta=a, alpha=1.0)
X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
else:
for _ in range(steps):
A = torch.bmm(X, X.mT)
B = torch.baddbmm(A, A, A, beta=b, alpha=c)
X = torch.baddbmm(X, B, X, beta=a, alpha=1.0)

return X.to(dtype).view(original_shape)


class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
Expand All @@ -87,7 +122,6 @@ class Muon(torch.optim.Optimizer):
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
super().__init__(params, defaults)
self.bf16_support_map = get_bf16_support_map()

@torch.no_grad()
def step(self, closure=None):
Expand Down Expand Up @@ -116,8 +150,8 @@ def step(self, closure=None):
original_shape = g.shape
if g.ndim >= 4: # for the case of conv filters
g = g.view(g.size(0), g.size(1), -1)
use_bf16 = self.bf16_support_map.get(g.device, False)
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
g = gram_newton_schulz(g, steps=group["ns_steps"])

if group["weight_decay"] > 0:
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)
Expand Down