From f16084daed474c151378598edb1866f3353a7457 Mon Sep 17 00:00:00 2001 From: Noah Amsel Date: Tue, 7 Apr 2026 08:16:12 -0700 Subject: [PATCH 1/3] normuon default: adjust_lr = "spectral_norm" --- dion/normuon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dion/normuon.py b/dion/normuon.py index 831842e..5782081 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -62,7 +62,7 @@ def __init__( cautious_wd: bool = False, epsilon: float = 1e-8, nesterov: bool = False, - adjust_lr: Optional[str] = "rms_norm", + adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, use_gram_newton_schulz: bool = False, use_triton: bool = False, From 199894b0059b1f0d81bd07cf59e30e644f5ffb92 Mon Sep 17 00:00:00 2001 From: Noah Amsel Date: Tue, 7 Apr 2026 08:47:38 -0700 Subject: [PATCH 2/3] default: use_gram_newton_schulz=True, use_polar_express=True --- dion/dion2.py | 4 ++-- dion/megabatch_base.py | 4 ++-- dion/muon.py | 4 ++-- dion/normuon.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dion/dion2.py b/dion/dion2.py index e79eab1..8837aea 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -59,9 +59,9 @@ def __init__( epsilon: float = 1e-8, adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, verbose: bool = False, ): diff --git a/dion/megabatch_base.py b/dion/megabatch_base.py index 92f8e2c..f4278f3 100644 --- a/dion/megabatch_base.py +++ b/dion/megabatch_base.py @@ -34,9 +34,9 @@ def __init__( distributed_mesh: Optional[Union[DeviceMesh, ProcessGroup]], algo_name: str, defaults: dict, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = 
None, ): super().__init__(params, defaults) diff --git a/dion/muon.py b/dion/muon.py index c24d5fc..a320052 100644 --- a/dion/muon.py +++ b/dion/muon.py @@ -61,9 +61,9 @@ def __init__( nesterov: bool = False, adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, ): if lr < 0.0: diff --git a/dion/normuon.py b/dion/normuon.py index 5782081..a419186 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -64,9 +64,9 @@ def __init__( nesterov: bool = False, adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, ): if lr < 0.0: From a5fd82e98dae0f3c63ede5902aa2041ec89c7635 Mon Sep 17 00:00:00 2001 From: John Langford Date: Wed, 8 Apr 2026 07:40:16 -0700 Subject: [PATCH 3/3] Add clear error message when gram-newton-schulz package is missing Since use_gram_newton_schulz now defaults to True, a missing package produces a confusing bare ImportError. Wrap the import in try/except with an actionable message, matching the existing use_triton pattern. 
--- dion/megabatch_base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dion/megabatch_base.py b/dion/megabatch_base.py index f4278f3..519e6ce 100644 --- a/dion/megabatch_base.py +++ b/dion/megabatch_base.py @@ -75,7 +75,14 @@ def __init__( ) self._newton_schulz_func = newton_schulz_func elif use_gram_newton_schulz: - from gram_newton_schulz import GramNewtonSchulz + try: + from gram_newton_schulz import GramNewtonSchulz + except ImportError as err: + raise ImportError( + "use_gram_newton_schulz=True requires the 'gram-newton-schulz' package, " + "which is not installed. " + "Install it with: pip install gram-newton-schulz" + ) from err use_polar_express = True _gns = GramNewtonSchulz( ns_use_kernels=use_triton,