69 changes: 67 additions & 2 deletions axlearn/common/optimizers.py
@@ -442,7 +442,7 @@ def init_fn(params):
count = None
else:
count = jnp.zeros([], jnp.int32)
return AddDecayedWeightsState(count=count)
return AddDecayedWeightsState(count=jnp.zeros((1,), dtype=jnp.int32))

def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: NestedOptParam):
if params is None:
@@ -705,6 +705,71 @@ def adamw_decoupled_optimizer(

return chain(*tx)

def multisteps_optimizer(
    *args, k_steps: int
) -> PartitionedGradientTransformation:
    """Wraps chain(*args) so that gradients accumulate over `k_steps` mini-steps,
    with the chained update applied once per cycle."""
base = chain(*args)

    class MultiStepsState(NamedTuple):
        mini_step: Any  # Scalar int32: position within the current k-step cycle.
        gradient_step: Any  # Scalar int32: number of emitted optimizer updates.
        inner_opt_state: Any
        acc_grads: Any  # Accumulated gradients; same tree structure as params.

def multisteps_init(model_params):
        def _copy_zero(model_tree):
            return jax.tree_util.tree_map(jnp.zeros_like, model_tree)
        return MultiStepsState(
            mini_step=jnp.zeros([], dtype=jnp.int32),
            gradient_step=jnp.zeros([], dtype=jnp.int32),
            inner_opt_state=base.init(model_params),
            acc_grads=_copy_zero(model_params),
        )

def multisteps_partition(optimizer_model_param_specs):
        return MultiStepsState(
            mini_step=None,
            gradient_step=None,
            inner_opt_state=base.partition(optimizer_model_param_specs),
            # TODO: verify that the param specs also partition acc_grads correctly
            # (this was model_param_specs).
            acc_grads=optimizer_model_param_specs,
        )

    # Adapted from optax.MultiSteps.
    def multisteps_update(updates, state, params):
        # Accumulate the incoming gradients into the running sum.
        acc_grads = jax.tree_util.tree_map(
            lambda upd, acc: upd + acc,
            updates,
            state.acc_grads,
        )

        # Run the chained update on the accumulated gradients every mini-step;
        # the result is masked below so it only takes effect on emitting steps.
        final_updates, new_inner_state = base.update(
            acc_grads, state.inner_opt_state, params=params
        )

        # Emit the accumulated update only on the last mini-step of each cycle.
        emit = state.mini_step == (k_steps - 1)
        new_state = MultiStepsState(
            # mini_step cycles through [0, k_steps).
            mini_step=optax.safe_int32_increment(state.mini_step) % k_steps,
            # gradient_step advances only when an update is emitted.
            gradient_step=emit * optax.safe_int32_increment(state.gradient_step)
            + (1 - emit) * state.gradient_step,
            # Keep the previous inner state on non-emitting steps.
            inner_opt_state=jax.tree_util.tree_map(
                lambda st, nst: jnp.where(emit, nst, st),
                state.inner_opt_state,
                new_inner_state,
            ),
            # Reset the accumulator after an emit; otherwise keep accumulating.
            acc_grads=jax.tree_util.tree_map(
                lambda ga: (1 - emit) * ga, acc_grads
            ),
        )

        # Zero out the updates on non-emitting mini-steps.
        final_updates = jax.tree_util.tree_map(
            lambda ga: emit * ga, final_updates
        )
        return final_updates, new_state

return PartitionedGradientTransformation(
init=multisteps_init, update=multisteps_update, partition=multisteps_partition
)

def adam_optimizer(
learning_rate: schedule.Schedule,
Expand Down Expand Up @@ -1521,7 +1586,7 @@ def _init(param: OptParam):
)

return _AdastarState(
count=jnp.zeros([], dtype=jnp.int32), pps=jax.tree_util.tree_map(_init, params)
count=jnp.zeros((1,), dtype=jnp.int32), pps=jax.tree_util.tree_map(_init, params)
)

def update_fn(grads: NestedTensor, state: _AdastarState, params: NestedOptParam):
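For intuition, here is a minimal, self-contained sketch of the accumulation arithmetic that the new multisteps_optimizer relies on. It is a sketch only, using plain optax with illustrative names and numbers; the PR's version runs under jit, so it computes the inner update on every mini-step and masks it with `emit` rather than branching in Python.

# Minimal sketch of k-step gradient accumulation, assuming plain optax/JAX.
# Names and numbers here are illustrative, not taken from the PR.
import jax
import jax.numpy as jnp
import optax

k_steps = 4
base = optax.sgd(learning_rate=0.1)

params = {"w": jnp.ones((2,))}
opt_state = base.init(params)
acc = jax.tree_util.tree_map(jnp.zeros_like, params)

for step in range(2 * k_steps):
    grads = {"w": jnp.full((2,), 0.5)}  # stand-in gradient
    acc = jax.tree_util.tree_map(lambda g, a: g + a, grads, acc)
    if step % k_steps == k_steps - 1:  # the `emit` condition
        # Apply the summed gradient once per cycle, then reset the accumulator.
        updates, opt_state = base.update(acc, opt_state, params)
        params = optax.apply_updates(params, updates)
        acc = jax.tree_util.tree_map(jnp.zeros_like, acc)

One behavioral note: unlike optax.MultiSteps, which averages gradients across the window by default, the diff sums them, so each emitted update sees a gradient roughly k_steps times larger in magnitude.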
8 changes: 6 additions & 2 deletions axlearn/experiments/text/gpt/common.py
@@ -284,7 +284,10 @@ def learner_config(
# Decay to this fraction of the peak_lr.
alpha=alpha,
)
optimizer_cfg = config_for_function(optimizers.chain).set(
    # Use the multi-steps optimizer instead of the plain chain. The same switch
    # can also be made through a config-file change, by replacing the key-value pair
    # learner.optimizer.fn: 'axlearn.common.optimizers.chain'
    # with learner.optimizer.fn: 'axlearn.common.optimizers.multisteps_optimizer'.
optimizer_cfg = config_for_function(optimizers.multisteps_optimizer).set(
args=[
config_for_function(optimizers.clip_by_global_norm).set(max_norm=1),
config_for_function(optimizers.adamw_decoupled_optimizer).set(
@@ -298,7 +301,8 @@
adam_update_transformation=None,
mu_dtype=jnp.float32
),
]
],
        k_steps=4,  # Accumulate gradients over 4 mini-steps per optimizer update.
)
return learner.Learner.default_config().set(optimizer=optimizer_cfg)

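With k_steps=4, the learner now performs one optimizer update per four forward/backward passes. A back-of-envelope check with made-up numbers (the batch size below is a hypothetical stand-in, not from this config):

# Illustrative arithmetic only; per_pass_batch is a hypothetical value.
k_steps = 4
per_pass_batch = 512  # hypothetical per-mini-step batch size
effective_batch = k_steps * per_pass_batch  # 2048 examples per optimizer update

train_passes = 1000
optimizer_updates = train_passes // k_steps  # 250 emitted updates

Since the inner optimizer state is only replaced on emitting steps, any schedule driven by the inner step count (e.g. learning-rate warmup) should advance once per emitted update; this is exactly the count that the gradient_step field in MultiStepsState tracks.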