diff --git a/dion/dion2.py b/dion/dion2.py index e79eab1..8837aea 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -59,9 +59,9 @@ def __init__( epsilon: float = 1e-8, adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, - use_gram_newton_schulz: bool = False, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, + use_gram_newton_schulz: bool = True, newton_schulz_func: Optional[Callable] = None, verbose: bool = False, ): diff --git a/dion/megabatch_base.py b/dion/megabatch_base.py index 92f8e2c..f4278f3 100644 --- a/dion/megabatch_base.py +++ b/dion/megabatch_base.py @@ -34,9 +34,9 @@ def __init__( distributed_mesh: Optional[Union[DeviceMesh, ProcessGroup]], algo_name: str, defaults: dict, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, ): super().__init__(params, defaults) diff --git a/dion/muon.py b/dion/muon.py index c24d5fc..a320052 100644 --- a/dion/muon.py +++ b/dion/muon.py @@ -61,9 +61,9 @@ def __init__( nesterov: bool = False, adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, ): if lr < 0.0: diff --git a/dion/normuon.py b/dion/normuon.py index 831842e..a419186 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -62,11 +62,11 @@ def __init__( cautious_wd: bool = False, epsilon: float = 1e-8, nesterov: bool = False, - adjust_lr: Optional[str] = "rms_norm", + adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, - use_gram_newton_schulz: bool = False, + use_gram_newton_schulz: bool = True, use_triton: bool = False, - use_polar_express: bool = False, + use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, ): if lr < 0.0: