diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py
index cf512524b..85df68a07 100644
--- a/axlearn/common/optimizers.py
+++ b/axlearn/common/optimizers.py
@@ -442,7 +442,7 @@ def init_fn(params):
             count = None
         else:
             count = jnp.zeros([], jnp.int32)
-        return AddDecayedWeightsState(count=count)
+        return AddDecayedWeightsState(count=jnp.zeros((1,), dtype=jnp.int32))
 
     def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: NestedOptParam):
         if params is None:
@@ -705,6 +705,71 @@ def adamw_decoupled_optimizer(
     return chain(*tx)
 
 
+def multisteps_optimizer(
+    *args, k_steps: int
+) -> PartitionedGradientTransformation:
+    base = chain(*args)
+
+    class MultiStepsState(NamedTuple):
+        mini_step: int
+        gradient_step: int
+        inner_opt_state: Any
+        acc_grads: Any
+
+    def multisteps_init(model_params):
+        def _copy_zero(model_tree):
+            return jax.tree_util.tree_map(lambda x: jnp.full_like(x, 0), model_tree)
+
+        return MultiStepsState(
+            mini_step=jnp.zeros([], dtype=jnp.int32),
+            gradient_step=jnp.zeros([], dtype=jnp.int32),
+            inner_opt_state=base.init(model_params),
+            acc_grads=_copy_zero(model_params))
+
+    def multisteps_partition(optimizer_model_param_specs):
+        return MultiStepsState(
+            mini_step=None,
+            gradient_step=None,
+            inner_opt_state=base.partition(optimizer_model_param_specs),
+            acc_grads=optimizer_model_param_specs)  # TODO: check that this works; was model_param_specs.
+
+    # Adapted from optax.MultiSteps.
+    def multisteps_update(updates, state, params):
+        _acc_update = lambda x, y: x + y
+        # Note: we do not enclose variables to allow JAX to re-use memory buffers.
+        acc_grads = jax.tree_util.tree_map(
+            lambda upd, acc: _acc_update(upd, acc),
+            updates,
+            state.acc_grads,
+        )
+
+        final_updates, new_inner_state = base.update(
+            acc_grads, state.inner_opt_state, params=params
+        )
+
+        emit = state.mini_step == (k_steps - 1)
+        new_state = MultiStepsState(
+            mini_step=optax.safe_int32_increment(state.mini_step) % k_steps,
+            gradient_step=emit * optax.safe_int32_increment(state.gradient_step)
+            + (1 - emit) * state.gradient_step,
+            inner_opt_state=jax.tree_util.tree_map(
+                lambda st, nst: jnp.where(emit, nst, st),
+                state.inner_opt_state,
+                new_inner_state,
+            ),
+            acc_grads=jax.tree_util.tree_map(
+                lambda ga: (1 - emit) * ga, acc_grads
+            ),
+        )
+
+        final_updates = jax.tree_util.tree_map(
+            lambda ga: emit * ga, final_updates
+        )
+        return final_updates, new_state
+
+    return PartitionedGradientTransformation(
+        init=multisteps_init, update=multisteps_update, partition=multisteps_partition
+    )
 def adam_optimizer(
     learning_rate: schedule.Schedule,
@@ -1521,7 +1586,7 @@ def _init(param: OptParam):
         )
 
     return _AdastarState(
-        count=jnp.zeros([], dtype=jnp.int32), pps=jax.tree_util.tree_map(_init, params)
+        count=jnp.zeros((1,), dtype=jnp.int32), pps=jax.tree_util.tree_map(_init, params)
     )
 
     def update_fn(grads: NestedTensor, state: _AdastarState, params: NestedOptParam):
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 9810849bb..a2ab4b323 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -284,7 +284,10 @@ def learner_config(
         # Decay to this fraction of the peak_lr.
         alpha=alpha,
     )
-    optimizer_cfg = config_for_function(optimizers.chain).set(
+    # Use multisteps_optimizer instead of chain so that gradients are accumulated over k_steps.
+    # This can also be done through a config file change, by updating the key-value pair
+    # learner.optimizer.fn: 'axlearn.common.optimizers.chain' to point at multisteps_optimizer.
+    optimizer_cfg = config_for_function(optimizers.multisteps_optimizer).set(
         args=[
             config_for_function(optimizers.clip_by_global_norm).set(max_norm=1),
             config_for_function(optimizers.adamw_decoupled_optimizer).set(
@@ -298,7 +301,8 @@ def learner_config(
                 adam_update_transformation=None, mu_dtype=jnp.float32
             ),
-        ]
+        ],
+        k_steps=4
     )
     return learner.Learner.default_config().set(optimizer=optimizer_cfg)
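
For reference, the sketch below is not part of the patch: it mirrors the accumulate-then-emit logic that the new multisteps_update implements, but on plain optax transforms rather than axlearn's PartitionedGradientTransformation. The SGD optimizer, the constant toy gradients, and k_steps=4 are illustrative assumptions, chosen to match the k_steps=4 set in learner_config above.

# Standalone sketch (illustrative only): gradients are summed for k_steps micro-steps,
# and the wrapped optimizer's update is emitted, its state advanced, and the
# accumulator reset only on every k-th call.
import jax
import jax.numpy as jnp
import optax

k_steps = 4
base = optax.sgd(learning_rate=0.1)

params = {"w": jnp.ones((3,))}
opt_state = base.init(params)
acc_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
mini_step = jnp.zeros([], dtype=jnp.int32)

for step in range(8):
    grads = {"w": jnp.full((3,), 0.5)}  # Stand-in for one micro-batch's gradients.
    # Accumulate gradients across micro-steps.
    acc_grads = jax.tree_util.tree_map(lambda g, a: g + a, grads, acc_grads)
    emit = mini_step == (k_steps - 1)
    updates, new_opt_state = base.update(acc_grads, opt_state, params)
    # Zero the emitted updates, keep the old inner state, and keep accumulating
    # on non-emit steps; apply everything on every k-th step.
    updates = jax.tree_util.tree_map(lambda u: emit * u, updates)
    opt_state = jax.tree_util.tree_map(
        lambda s, ns: jnp.where(emit, ns, s), opt_state, new_opt_state
    )
    acc_grads = jax.tree_util.tree_map(lambda a: (1 - emit) * a, acc_grads)
    params = optax.apply_updates(params, updates)
    mini_step = optax.safe_int32_increment(mini_step) % k_steps
    print(int(step), float(params["w"][0]))

With these toy values the printed parameter changes only on every fourth iteration, each time by the learning rate times the summed gradients; this is the behavior the k_steps=4 setting is intended to give the clip-by-global-norm plus adamw_decoupled_optimizer chain wrapped in learner_config.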