diff --git a/dion/muon.py b/dion/muon.py
index 92f444d..ad64722 100644
--- a/dion/muon.py
+++ b/dion/muon.py
@@ -45,6 +45,20 @@ class Muon(Optimizer):
         use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided.
         newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
             Signature is `func(input: Tensor, epsilon: float) -> Tensor`.
+        skip_update_prob: SkipUpdate survival probability p ∈ (0, 1].
+            At each step, each parameter matrix is independently kept with probability p and
+            skipped (zeroed out) with probability (1-p). Surviving updates are rescaled by
+            1/p to keep the update unbiased in expectation. Moment buffers always update
+            densely regardless of skip. None (default) disables SkipUpdate.
+            See: https://arxiv.org/abs/2602.15322
+        magma_tau: Magma temperature τ > 0. When set, enables Magma mode which replaces the
+            fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient alignment:
+                ẽ_t = sigmoid(cossim(momentum_before, grad) / τ)
+                s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t
+            The surviving update is scaled by s_t instead of 1/p. This is intentionally biased
+            (no 1/s_t correction) as unbiased variants were found to be unstable.
+            Requires skip_update_prob to also be set. None (default) uses plain SkipUpdate scaling.
+            See: https://arxiv.org/abs/2602.15322
 
     Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/
     FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra
@@ -65,6 +79,8 @@ def __init__(
         flatten: bool = False,
         use_triton: bool = False,
         newton_schulz_func: Optional[Callable] = None,
+        skip_update_prob: Optional[float] = None,
+        magma_tau: Optional[float] = None,
     ):
         # Check hyperparameters
         if lr < 0.0:
@@ -77,6 +93,13 @@
             raise ValueError(
                 f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None."
             )
+        # SkipUpdate / Magma: validate parameters
+        if skip_update_prob is not None and not (0.0 < skip_update_prob <= 1.0):
+            raise ValueError(f"skip_update_prob must be in (0, 1], got {skip_update_prob}")
+        if magma_tau is not None and magma_tau <= 0.0:
+            raise ValueError(f"magma_tau must be > 0, got {magma_tau}")
+        if magma_tau is not None and skip_update_prob is None:
+            raise ValueError("magma_tau requires skip_update_prob to be set")
 
         # Default arguments for each param group
         defaults = dict(
@@ -92,6 +115,8 @@
             nesterov=nesterov,
             flatten=flatten,
             adjust_lr=adjust_lr,
+            skip_update_prob=skip_update_prob,  # SkipUpdate: survival prob (None = disabled)
+            magma_tau=magma_tau,  # Magma: temperature for adaptive scaling (None = disabled)
         )
         super().__init__(params, defaults)
 
@@ -180,6 +205,9 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict:
             state["momentum"] = torch.zeros_like(param)
         if algo == "adamw":
             state["variance"] = torch.zeros_like(param)
+        if algo == "muon":
+            # Magma: per-param EMA scale, init=0.5 (neutral alignment)
+            state["magma_scale"] = torch.tensor(0.5, device=param.device, dtype=param.dtype)
         return state
 
     def _create_muon_tasks(
@@ -215,6 +243,8 @@
             process_group=self._process_group,
             newton_schulz_func=self._newton_schulz_func,
             cautious_wd=group["cautious_wd"],
+            skip_update_prob=group["skip_update_prob"],  # SkipUpdate: survival probability
+            magma_tau=group["magma_tau"],  # Magma: temperature (None = plain SkipUpdate)
         )
 
         # Create batches of parameters of size self._world_size
@@ -224,6 +254,7 @@
             gradients = [p.grad for p in params]
             states = [self._get_or_initialize_state(p, algo_name) for p in params]
             momentums = [s["momentum"] for s in states]
+            magma_scales = [s["magma_scale"] for s in states]  # Magma EMA scale per param
 
             # Get sharding state for DTensor
             is_batch_sharded = False
@@ -283,12 +314,13 @@
             # As long as matrix dimensions are not sharded, each device will have whole matrices
             # Each device already has different matrices of the batch, so we can't parallelize further
             if is_batch_sharded and not is_matrix_sharded:
-                for x, g, m in zip(params, gradients, momentums):
+                for x, g, m, s in zip(params, gradients, momentums, magma_scales):
                     yield AsyncTask(
                         muon_update_batch_async(
                             X=[x],
                             G=[g],
                             M=[m],
+                            S=[s],
                             shard_dim=None,  # No sharded matrix dim
                             **muon_update_args,
                         )
@@ -300,6 +332,7 @@
                         X=pad_batch(params, self._world_size),
                         G=pad_batch(gradients, self._world_size),
                         M=pad_batch(momentums, self._world_size),
+                        S=pad_batch(magma_scales, self._world_size),
                         shard_dim=sharded_tensor_dim,
                         **muon_update_args,
                     )
@@ -394,6 +427,7 @@ def muon_update_batch_async(
     X: List[Tensor],  # Model weights (modified in place)
     G: List[Tensor],  # Gradient
     M: List[Tensor],  # Momentum buffer (modified in place)
+    S: List[Tensor],  # Magma EMA scale buffer, scalar per param (modified in place)
     lr: Tensor,  # Learning rate (scalar tensor)
     momentum: Tensor,  # Momentum factor (scalar tensor)
     weight_decay: Tensor,  # Weight decay (scalar tensor)
@@ -407,6 +441,8 @@
     process_group: Optional[ProcessGroup] = None,
     newton_schulz_func: Optional[Callable] = None,
     cautious_wd: bool = False,
+    skip_update_prob: Optional[float] = None,  # SkipUpdate: survival probability (None = disabled)
+    magma_tau: Optional[float] = None,  # Magma: temperature for adaptive scaling (None = disabled)
 ) -> Generator[None, None, None]:
     """
     Batched version of Muon update. Batch size should be equal to number of GPUs.
@@ -417,10 +453,17 @@
     assert len(X) == len(G)
     assert len(X) == len(M)
 
+    # Magma: snapshot momentum before it's updated, for cosine similarity with current grad.
+    # muon_update_pre_orthogonalize updates M in-place, so we must clone beforehand.
+    G_local = to_local(G)
+    M_local = to_local(M)
+    if magma_tau is not None:
+        M_before = [m.clone() for m in M_local]
+
     # Update momentum and compute the inputs for orthogonalization
     U = muon_update_pre_orthogonalize(
-        G=to_local(G),
-        M=to_local(M),
+        G=G_local,
+        M=M_local,
         momentum=momentum,
         nesterov=nesterov,
     )
@@ -510,6 +553,34 @@
         epsilon=epsilon,
     )
 
+    # SkipUpdate / Magma: stochastic block masking per parameter matrix.
+    # Moments always update densely (above); only the final update direction is masked.
+    # Reference: "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers".
+    if skip_update_prob is not None and skip_update_prob < 1.0:
+        # muon_update_newton_schulz returns a Tensor, not a list; U may be a list already
+        U = list(U) if not isinstance(U, list) else U
+        S_local = to_local(S)
+
+        for i in range(len(U)):
+            # Sample one Bernoulli scalar per parameter block (not per element)
+            keep = torch.bernoulli(torch.tensor(skip_update_prob, device=U[i].device))
+
+            if magma_tau is not None:
+                # Magma: adaptive scale via momentum-gradient cosine similarity.
+                #   ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ)
+                #   s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t   (EMA, updated in-place)
+                mu = M_before[i].flatten().float()
+                g = G_local[i].flatten().float()
+                cos = torch.dot(mu, g) / (mu.norm() * g.norm() + 1e-8)
+                e_tilde = torch.sigmoid(cos / magma_tau)
+                S_local[i].mul_(0.9).add_(e_tilde * 0.1)  # EMA update in-place
+                scale = S_local[i]
+            else:
+                # Plain SkipUpdate: fixed unbiasing rescale of 1/p
+                scale = 1.0 / skip_update_prob
+
+            U[i] = U[i] * (keep * scale)  # zero-out or scale entire matrix
+
     # Compute scaled learning rate
     # Do this before to_local(X) because we use the full tensor shape, not the shard shape
     if adjust_lr is None:
diff --git a/dion/normuon.py b/dion/normuon.py
index 2e13f6d..5345451 100644
--- a/dion/normuon.py
+++ b/dion/normuon.py
@@ -43,6 +43,21 @@ class NorMuon(DistributedOrthoBase):
         use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided.
         newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
             Signature is ``func(input: Tensor, epsilon: float) -> Tensor``.
+        skip_update_prob: SkipUpdate survival probability p ∈ (0, 1].
+            At each step, each parameter matrix is independently kept with probability p and
+            skipped (zeroed out) with probability (1-p). Surviving updates are rescaled by
+            1/p to keep the update unbiased in expectation. Moment buffers (momentum and
+            variance_neuron) always update densely regardless of skip.
+            None (default) disables SkipUpdate.
+            See: https://arxiv.org/abs/2602.15322
+        magma_tau: Magma temperature τ > 0. When set, enables Magma mode which replaces the
+            fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient alignment:
+                ẽ_t = sigmoid(cossim(momentum_before, grad) / τ)
+                s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t
+            The surviving update is scaled by s_t instead of 1/p. This is intentionally biased
+            (no 1/s_t correction) as unbiased variants were found to be unstable.
+            Requires skip_update_prob to also be set. None (default) uses plain SkipUpdate scaling.
+            See: https://arxiv.org/abs/2602.15322
 
     Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/
     FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra
@@ -65,6 +80,8 @@ def __init__(
         flatten: bool = False,
         use_triton: bool = False,
         newton_schulz_func: Optional[Callable] = None,
+        skip_update_prob: Optional[float] = None,
+        magma_tau: Optional[float] = None,
     ):
         if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -78,6 +95,13 @@
             raise ValueError(
                 f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None."
             )
+        # SkipUpdate / Magma: validate parameters
+        if skip_update_prob is not None and not (0.0 < skip_update_prob <= 1.0):
+            raise ValueError(f"skip_update_prob must be in (0, 1], got {skip_update_prob}")
+        if magma_tau is not None and magma_tau <= 0.0:
+            raise ValueError(f"magma_tau must be > 0, got {magma_tau}")
+        if magma_tau is not None and skip_update_prob is None:
+            raise ValueError("magma_tau requires skip_update_prob to be set")
 
         defaults = dict(
             lr=lr,
@@ -93,6 +117,8 @@
             nesterov=nesterov,
             flatten=flatten,
             adjust_lr=adjust_lr,
+            skip_update_prob=skip_update_prob,  # SkipUpdate: survival prob (None = disabled)
+            magma_tau=magma_tau,  # Magma: temperature for adaptive scaling (None = disabled)
         )
         super().__init__(
             params, distributed_mesh, "normuon", defaults,
@@ -103,6 +129,8 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict:
         state = super()._get_or_initialize_state(param, algo)
         if algo == self._algo_name and "variance_neuron" not in state:
             state["variance_neuron"] = torch.zeros_like(param[..., 0:1])
+            # Magma: per-param EMA scale, init=0.5 (neutral alignment)
+            state["magma_scale"] = torch.tensor(0.5, device=param.device, dtype=param.dtype)
         return state
 
     def _get_shard_info(self, param: Tensor, group: dict):
@@ -146,6 +174,8 @@
             process_group=self._process_group,
             newton_schulz_func=self._newton_schulz_func,
             cautious_wd=group["cautious_wd"],
+            skip_update_prob=group["skip_update_prob"],  # SkipUpdate: survival probability
+            magma_tau=group["magma_tau"],  # Magma: temperature (None = plain SkipUpdate)
         )
 
         shape_groups: dict[tuple, list] = defaultdict(list)
@@ -158,6 +188,7 @@
             states = [self._get_or_initialize_state(p, self._algo_name) for p in params]
             momentums = [s["momentum"] for s in states]
             variances_neuron = [s["variance_neuron"] for s in states]
+            magma_scales = [s["magma_scale"] for s in states]  # Magma EMA scale per param
 
             is_batch_sharded, is_matrix_sharded, sharded_tensor_dim = (
                 self._get_shard_info(params[0], group)
@@ -173,6 +204,7 @@
                     G=gradients,
                     M=momentums,
                     V=variances_neuron,
+                    S=magma_scales,
                     shard_dim=sharded_tensor_dim,
                     **megabatch_args,
                 )
@@ -183,7 +215,8 @@ def normuon_update_megabatch_async(
     X: List[Tensor],
     G: List[Tensor],
     M: List[Tensor],
     V: List[Tensor],
+    S: List[Tensor],  # Magma EMA scale buffer, scalar per param (modified in place)
     lr: Tensor,
     momentum: Tensor,
     muon_beta2: Tensor,
@@ -198,6 +231,8 @@
     process_group: Optional[ProcessGroup] = None,
     newton_schulz_func: Optional[Callable] = None,
     cautious_wd: bool = False,
+    skip_update_prob: Optional[float] = None,  # SkipUpdate: survival probability (None = disabled)
+    magma_tau: Optional[float] = None,  # Magma: temperature for adaptive scaling (None = disabled)
 ) -> Generator[None, None, None]:
     """
     Mega-batched NorMuon update: processes ALL same-shape parameters in one
@@ -206,9 +241,16 @@
     N = len(X)
     assert N == len(G) == len(M) == len(V)
 
+    # Magma: snapshot momentum before it's updated, for cosine similarity with current grad.
+    # muon_update_pre_orthogonalize updates M in-place, so we must clone beforehand.
+    G_local = to_local(G)
+    M_local = to_local(M)
+    if magma_tau is not None:
+        M_before = [m.clone() for m in M_local]
+
     # Pre-orthogonalize: update momentum
     U = muon_update_pre_orthogonalize(
-        G=to_local(G), M=to_local(M), momentum=momentum, nesterov=nesterov,
+        G=G_local, M=M_local, momentum=momentum, nesterov=nesterov,
     )
 
     # Convert shard_dim to negative for comm_dim
@@ -235,6 +277,33 @@
         V_local[i].copy_(V_stacked[i])
     U = [U_stacked[i] for i in range(N)]
 
+    # SkipUpdate / Magma: stochastic block masking per parameter matrix.
+    # Moments always update densely (above); only the final update direction is masked.
+    # Reference: "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers".
+    if skip_update_prob is not None and skip_update_prob < 1.0:
+        U = list(U)
+        S_local = to_local(S)
+
+        for i in range(len(U)):
+            # Sample one Bernoulli scalar per parameter block (not per element)
+            keep = torch.bernoulli(torch.tensor(skip_update_prob, device=U[i].device))
+
+            if magma_tau is not None:
+                # Magma: adaptive scale via momentum-gradient cosine similarity.
+                #   ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ)
+                #   s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t   (EMA, updated in-place)
+                mu = M_before[i].flatten().float()
+                g = G_local[i].flatten().float()
+                cos = torch.dot(mu, g) / (mu.norm() * g.norm() + 1e-8)
+                e_tilde = torch.sigmoid(cos / magma_tau)
+                S_local[i].mul_(0.9).add_(e_tilde * 0.1)  # EMA update in-place
+                scale = S_local[i]
+            else:
+                # Plain SkipUpdate: fixed unbiasing rescale of 1/p
+                scale = 1.0 / skip_update_prob
+
+            U[i] = U[i] * (keep * scale)  # zero-out or scale entire matrix
+
     # Compute scaled learning rate
     if adjust_lr is None:
         adjusted_lr = lr