From 5acb2907bc28d05f4af73e64fec3219e54c8d905 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 25 Mar 2024 16:55:58 +0000 Subject: [PATCH 1/6] gradient accumulation using optax multisteps --- axlearn/common/learner.py | 84 +++++++++++++++++++++++++++++++++--- axlearn/common/optimizers.py | 14 +++++- 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py index a19e9c718..0a7dbd0ba 100644 --- a/axlearn/common/learner.py +++ b/axlearn/common/learner.py @@ -9,7 +9,8 @@ import dataclasses import enum -from typing import Mapping, Optional, Sequence, Tuple +from typing import Mapping, Optional, Sequence, Tuple, Any, NamedTuple +# from typing import , Callable, Dict, List, , Optional, Sequence, Tuple, Union import jax import optax @@ -25,7 +26,7 @@ ) from axlearn.common.module import Module from axlearn.common.optimizer_base import NestedOptParam, PartitionedGradientTransformation -from axlearn.common.optimizers import param_ema +from axlearn.common.optimizers import param_ema, MultiStepsState from axlearn.common.utils import ( NestedPartitionSpec, NestedTensor, @@ -163,7 +164,21 @@ def create_state_partition_specs( self, model_param_specs: NestedParameterSpec ) -> NestedPartitionSpec: optimizer_model_param_specs = self._get_optimizer_model_params(model_param_specs) - partition_state = dict(optimizer=self.optimizer.partition(optimizer_model_param_specs)) + # mystate = optax.MultiStepsState(None, None, None, None, None) + # dict( + # optimizer=optax.MultiStepsState + # ) + print("In create_state_partition_specs\n") + + partition_state = dict( + optimizer=MultiStepsState( + mini_step = None, + gradient_step = None, + inner_opt_state = self.optimizer.partition(optimizer_model_param_specs), + acc_grads = model_param_specs) + ) + + # partition_state = dict(optimizer=self.optimizer.partition(optimizer_model_param_specs)) if self.config.ema.decay is not None: partition_state["ema"] = self.ema.partition(model_param_specs) return partition_state @@ -179,8 +194,25 @@ def _get_optimizer_model_params(self, model_params: NestedOptParam) -> NestedOpt def init(self, model_params: NestedOptParam) -> NestedTensor: update_types = self._update_types(model_params) register_per_param_settings(update_types, description="learner_update_type") + print("Optimizer state init") + + # print("Multistep init") + # tmp = self.optimizer.partition + # multisteps = optax.MultiSteps(self.optimizer, 2) + # self.optimizer = PartitionedGradientTransformation( + # init=multisteps.init, update=multisteps.update, partition=tmp + # ) + + def _copy_zero(model_tree): + return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) + + grad_model_params = self._get_optimizer_model_params(model_params) state = dict( - optimizer=self.optimizer.init(self._get_optimizer_model_params(model_params)), + optimizer=MultiStepsState( + mini_step = jnp.zeros([], dtype=jnp.int32), + gradient_step = jnp.zeros([], dtype=jnp.int32), + inner_opt_state = self.optimizer.init(grad_model_params), + acc_grads = _copy_zero(grad_model_params)) ) if self.config.ema.decay is not None: state["ema"] = self.ema.init(model_params) @@ -227,11 +259,53 @@ def update( """ cfg = self.config optimizer_model_params = self._get_optimizer_model_params(model_params) - optimizer_parameter_updates, optimizer_state = self.optimizer.update( + print("params in self.optimizer.update", optimizer_model_params) + k_steps = 4 + _acc_update = lambda x, y: x + y + + # Note: we do not enclose variables to allow JAX to re-use 
memory buffers. + def _do_update(updates, state, params): + acc_grads = jax.tree_util.tree_map( + lambda upd, acc: _acc_update(upd, acc), + updates, + state.acc_grads, + ) + + final_updates, new_inner_state = self.optimizer.update( + acc_grads, state.inner_opt_state, params=params + ) + + emit = state.mini_step == (k_steps - 1) + new_state = MultiStepsState( + mini_step=optax.safe_int32_increment(state.mini_step) % k_steps, + gradient_step=emit + * optax.safe_int32_increment(state.gradient_step) + + (1 - emit) * state.gradient_step, + inner_opt_state=jax.tree_util.tree_map( + lambda st, nst: jnp.where(emit, nst, st), + state.inner_opt_state, + new_inner_state, + ), + acc_grads=jax.tree_util.tree_map( + lambda ga: (1 - emit) * ga, acc_grads + ), + ) + + final_updates = jax.tree_util.tree_map( + lambda ga: emit * ga, final_updates + ) + return final_updates, new_state + + optimizer_parameter_updates, optimizer_state = _do_update( gradients, state=self.state["optimizer"], params=optimizer_model_params, ) + # optimizer_parameter_updates, optimizer_state = self.optimizer.update( + # gradients, + # state=self.state["optimizer"], + # params=optimizer_model_params, + # ) self.add_state_update("optimizer", optimizer_state) if cfg.enable_per_variable_summaries: param_rms = jax.tree_util.tree_map( diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index cf512524b..39e46a200 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -113,7 +113,8 @@ def init_fn(params: NestedOptParam) -> NestedTensor: def update_fn( updates: optax.Updates, state: optax.OptState, params: NestedOptParam ) -> Tuple[optax.Updates, optax.OptState]: - return base.update(updates, state, opt_param_values(params)) + print("params in update_fn", params) + return base.update(updates, state, opt_param_values(params)) # TODO: apoorvgu params wrong here return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) @@ -684,6 +685,8 @@ def adamw_decoupled_optimizer( A PartitionedGradientTransformation representing a decoupled AdamW optimizer with parameter scaling. 
""" + # optax.MultiSteps(optimizer, every_k_schedule=3) + # tx = [adam_partition(optax.MultiSteps(optax.scale_by_adam(b1=b1, b2=b2, eps=eps, mu_dtype=mu_dtype), every_k_schedule=2))] tx = [adam_partition(optax.scale_by_adam(b1=b1, b2=b2, eps=eps, mu_dtype=mu_dtype))] if adam_update_transformation is not None: tx.append(maybe_instantiate(adam_update_transformation)) @@ -706,6 +709,13 @@ def adamw_decoupled_optimizer( return chain(*tx) +class MultiStepsState(NamedTuple): + mini_step: int + gradient_step: int + inner_opt_state: Any + acc_grads: Any + + def adam_optimizer( learning_rate: schedule.Schedule, *, @@ -1521,7 +1531,7 @@ def _init(param: OptParam): ) return _AdastarState( - count=jnp.zeros([], dtype=jnp.int32), pps=jax.tree_util.tree_map(_init, params) + count=jnp.zeros([], jnp.int32), pps=jax.tree_util.tree_map(_init, params) ) def update_fn(grads: NestedTensor, state: _AdastarState, params: NestedOptParam): From a2c3132889a76c1f134ec039282e8fcea6598ab9 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 25 Mar 2024 19:43:45 +0000 Subject: [PATCH 2/6] refactored multisteps to localize changes to single elegant optimizer function --- axlearn/common/learner.py | 82 +----------- axlearn/common/optimizers.py | 171 ++++++++++++++++++++++++- axlearn/experiments/text/gpt/common.py | 6 +- 3 files changed, 175 insertions(+), 84 deletions(-) diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py index 0a7dbd0ba..39e67c096 100644 --- a/axlearn/common/learner.py +++ b/axlearn/common/learner.py @@ -26,7 +26,7 @@ ) from axlearn.common.module import Module from axlearn.common.optimizer_base import NestedOptParam, PartitionedGradientTransformation -from axlearn.common.optimizers import param_ema, MultiStepsState +from axlearn.common.optimizers import param_ema from axlearn.common.utils import ( NestedPartitionSpec, NestedTensor, @@ -164,21 +164,7 @@ def create_state_partition_specs( self, model_param_specs: NestedParameterSpec ) -> NestedPartitionSpec: optimizer_model_param_specs = self._get_optimizer_model_params(model_param_specs) - # mystate = optax.MultiStepsState(None, None, None, None, None) - # dict( - # optimizer=optax.MultiStepsState - # ) - print("In create_state_partition_specs\n") - - partition_state = dict( - optimizer=MultiStepsState( - mini_step = None, - gradient_step = None, - inner_opt_state = self.optimizer.partition(optimizer_model_param_specs), - acc_grads = model_param_specs) - ) - - # partition_state = dict(optimizer=self.optimizer.partition(optimizer_model_param_specs)) + partition_state = self.optimizer.partition(optimizer_model_param_specs) if self.config.ema.decay is not None: partition_state["ema"] = self.ema.partition(model_param_specs) return partition_state @@ -194,26 +180,8 @@ def _get_optimizer_model_params(self, model_params: NestedOptParam) -> NestedOpt def init(self, model_params: NestedOptParam) -> NestedTensor: update_types = self._update_types(model_params) register_per_param_settings(update_types, description="learner_update_type") - print("Optimizer state init") - - # print("Multistep init") - # tmp = self.optimizer.partition - # multisteps = optax.MultiSteps(self.optimizer, 2) - # self.optimizer = PartitionedGradientTransformation( - # init=multisteps.init, update=multisteps.update, partition=tmp - # ) - - def _copy_zero(model_tree): - return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) - grad_model_params = self._get_optimizer_model_params(model_params) - state = dict( - optimizer=MultiStepsState( - mini_step = jnp.zeros([], 
dtype=jnp.int32), - gradient_step = jnp.zeros([], dtype=jnp.int32), - inner_opt_state = self.optimizer.init(grad_model_params), - acc_grads = _copy_zero(grad_model_params)) - ) + state = self.optimizer.init(grad_model_params) if self.config.ema.decay is not None: state["ema"] = self.ema.init(model_params) return state @@ -259,53 +227,11 @@ def update( """ cfg = self.config optimizer_model_params = self._get_optimizer_model_params(model_params) - print("params in self.optimizer.update", optimizer_model_params) - k_steps = 4 - _acc_update = lambda x, y: x + y - - # Note: we do not enclose variables to allow JAX to re-use memory buffers. - def _do_update(updates, state, params): - acc_grads = jax.tree_util.tree_map( - lambda upd, acc: _acc_update(upd, acc), - updates, - state.acc_grads, - ) - - final_updates, new_inner_state = self.optimizer.update( - acc_grads, state.inner_opt_state, params=params - ) - - emit = state.mini_step == (k_steps - 1) - new_state = MultiStepsState( - mini_step=optax.safe_int32_increment(state.mini_step) % k_steps, - gradient_step=emit - * optax.safe_int32_increment(state.gradient_step) - + (1 - emit) * state.gradient_step, - inner_opt_state=jax.tree_util.tree_map( - lambda st, nst: jnp.where(emit, nst, st), - state.inner_opt_state, - new_inner_state, - ), - acc_grads=jax.tree_util.tree_map( - lambda ga: (1 - emit) * ga, acc_grads - ), - ) - - final_updates = jax.tree_util.tree_map( - lambda ga: emit * ga, final_updates - ) - return final_updates, new_state - - optimizer_parameter_updates, optimizer_state = _do_update( + optimizer_parameter_updates, optimizer_state = self.optimizer.update( gradients, state=self.state["optimizer"], params=optimizer_model_params, ) - # optimizer_parameter_updates, optimizer_state = self.optimizer.update( - # gradients, - # state=self.state["optimizer"], - # params=optimizer_model_params, - # ) self.add_state_update("optimizer", optimizer_state) if cfg.enable_per_variable_summaries: param_rms = jax.tree_util.tree_map( diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 39e46a200..4f89c14b1 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -708,13 +708,174 @@ def adamw_decoupled_optimizer( return chain(*tx) +# def chain(*args): +# args = [_to_partitioned_transformation(e) for e in args] +# base = optax.chain(*[optax.GradientTransformation(init=e.init, update=e.update) for e in args]) -class MultiStepsState(NamedTuple): - mini_step: int - gradient_step: int - inner_opt_state: Any - acc_grads: Any +# def partition(param_spec): +# return tuple(e.partition(param_spec) for e in args) +# return PartitionedGradientTransformation( +# init=base.init, update=base.update, partition=partition +# ) + +def multisteps_optimizer( + *args +) -> PartitionedGradientTransformation: + base = chain(*args) + k_steps = 4 + + class MultiStepsState(NamedTuple): + mini_step: int + gradient_step: int + inner_opt_state: Any + acc_grads: Any + + def multisteps_init(model_params): + def _copy_zero(model_tree): + return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) + state = dict( + optimizer=MultiStepsState( + mini_step = jnp.zeros([], dtype=jnp.int32), + gradient_step = jnp.zeros([], dtype=jnp.int32), + inner_opt_state = base.init(model_params), + acc_grads = _copy_zero(model_params)) + ) + print("Inside multisteps_init state type is", type(state)) + + return state + + def multisteps_partition(optimizer_model_param_specs): + return dict( + optimizer=MultiStepsState( + mini_step = None, + 
gradient_step = None, + inner_opt_state = base.partition(optimizer_model_param_specs), + acc_grads = optimizer_model_param_specs) # TODO: check if this works was model_param_specs + ) + + # Copied from Optax Multistep + def multisteps_update(updates, state, params): + _acc_update = lambda x, y: x + y + # Note: we do not enclose variables to allow JAX to re-use memory buffers. + acc_grads = jax.tree_util.tree_map( + lambda upd, acc: _acc_update(upd, acc), + updates, + state.acc_grads, + ) + + final_updates, new_inner_state = base.update( + acc_grads, state.inner_opt_state, params=params + ) + + emit = state.mini_step == (k_steps - 1) + new_state = MultiStepsState( + mini_step=optax.safe_int32_increment(state.mini_step) % k_steps, + gradient_step=emit + * optax.safe_int32_increment(state.gradient_step) + + (1 - emit) * state.gradient_step, + inner_opt_state=jax.tree_util.tree_map( + lambda st, nst: jnp.where(emit, nst, st), + state.inner_opt_state, + new_inner_state, + ), + acc_grads=jax.tree_util.tree_map( + lambda ga: (1 - emit) * ga, acc_grads + ), + ) + + final_updates = jax.tree_util.tree_map( + lambda ga: emit * ga, final_updates + ) + return final_updates, new_state + + return PartitionedGradientTransformation( + init=multisteps_init, update=multisteps_update, partition=multisteps_partition + ) +# def multisteps_optimizer( +# learning_rate: float, +# *, +# b1: float, +# b2: float, +# eps: float, +# update_schedule: schedule.Schedule, +# k_steps: int, +# weight_decay: float = 0, +# weight_decay_per_param_scale: Optional[Callable[[NestedOptParam], Any]] = None, +# mu_dtype: Optional[jnp.dtype] = None, +# adam_update_transformation: Optional[ConfigOr[PartitionedGradientTransformation]] = None, +# ) -> PartitionedGradientTransformation: +# breakpoint() + +# chained_transformations = adamw_decoupled_optimizer(learning_rate=learning_rate, b1=b1, b2=b2,eps=eps,update_schedule=update_schedule, weight_decay=weight_decay, weight_decay_per_param_scale=weight_decay_per_param_scale,adam_update_transformation=adam_update_transformation,mu_dtype=mu_dtype) + +# class MultiStepsState(NamedTuple): +# mini_step: int +# gradient_step: int +# inner_opt_state: Any +# acc_grads: Any + +# def multisteps_init(model_params): +# def _copy_zero(model_tree): +# return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) +# state = dict( +# optimizer=MultiStepsState( +# mini_step = jnp.zeros([], dtype=jnp.int32), +# gradient_step = jnp.zeros([], dtype=jnp.int32), +# inner_opt_state = chained_transformations.init(model_params), +# acc_grads = _copy_zero(model_params)) +# ) +# print("Inside multisteps_init state type is", type(state)) + +# return state + +# def multisteps_partition(optimizer_model_param_specs): +# return dict( +# optimizer=MultiStepsState( +# mini_step = None, +# gradient_step = None, +# inner_opt_state = chained_transformations.partition(optimizer_model_param_specs), +# acc_grads = optimizer_model_param_specs) # TODO: check if this works was model_param_specs +# ) + +# # Copied from Optax Multistep +# def multisteps_update(updates, state, params): +# _acc_update = lambda x, y: x + y +# # Note: we do not enclose variables to allow JAX to re-use memory buffers. 
+# acc_grads = jax.tree_util.tree_map( +# lambda upd, acc: _acc_update(upd, acc), +# updates, +# state.acc_grads, +# ) + +# final_updates, new_inner_state = chained_transformations.update( +# acc_grads, state.inner_opt_state, params=params +# ) + +# emit = state.mini_step == (k_steps - 1) +# new_state = MultiStepsState( +# mini_step=optax.safe_int32_increment(state.mini_step) % k_steps, +# gradient_step=emit +# * optax.safe_int32_increment(state.gradient_step) +# + (1 - emit) * state.gradient_step, +# inner_opt_state=jax.tree_util.tree_map( +# lambda st, nst: jnp.where(emit, nst, st), +# state.inner_opt_state, +# new_inner_state, +# ), +# acc_grads=jax.tree_util.tree_map( +# lambda ga: (1 - emit) * ga, acc_grads +# ), +# ) + +# final_updates = jax.tree_util.tree_map( +# lambda ga: emit * ga, final_updates +# ) +# return final_updates, new_state + +# return PartitionedGradientTransformation( +# init=multisteps_init, update=multisteps_update, partition=multisteps_partition +# ) def adam_optimizer( learning_rate: schedule.Schedule, diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 9810849bb..6ce49541e 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -284,7 +284,10 @@ def learner_config( # Decay to this fraction of the peak_lr. alpha=alpha, ) - optimizer_cfg = config_for_function(optimizers.chain).set( + # Change the optimizer requirement to optimizer multisteps instead of chain, + # this can be done through a config file change also by changing key-value pair: + # learner.optimizer.fn: 'axlearn.common.optimizers.chain' + optimizer_cfg = config_for_function(optimizers.multisteps_optimizer).set( args=[ config_for_function(optimizers.clip_by_global_norm).set(max_norm=1), config_for_function(optimizers.adamw_decoupled_optimizer).set( @@ -293,6 +296,7 @@ def learner_config( b2=b2, eps=eps, update_schedule=update_schedule, + # k_steps=4, weight_decay=weight_decay, weight_decay_per_param_scale=None, adam_update_transformation=None, From 541d1a7f0ec1d7c4427c97df5d9fdb1176b5b950 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 25 Mar 2024 19:59:36 +0000 Subject: [PATCH 3/6] added config steps for optax multisteps, everything can be initialized from axlearn model config --- axlearn/common/optimizers.py | 103 +------------------------ axlearn/experiments/text/gpt/common.py | 4 +- 2 files changed, 4 insertions(+), 103 deletions(-) diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 4f89c14b1..4590c4ccf 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -708,22 +708,10 @@ def adamw_decoupled_optimizer( return chain(*tx) -# def chain(*args): -# args = [_to_partitioned_transformation(e) for e in args] -# base = optax.chain(*[optax.GradientTransformation(init=e.init, update=e.update) for e in args]) - -# def partition(param_spec): -# return tuple(e.partition(param_spec) for e in args) - -# return PartitionedGradientTransformation( -# init=base.init, update=base.update, partition=partition -# ) - def multisteps_optimizer( - *args + *args, k_steps ) -> PartitionedGradientTransformation: base = chain(*args) - k_steps = 4 class MultiStepsState(NamedTuple): mini_step: int @@ -734,16 +722,13 @@ class MultiStepsState(NamedTuple): def multisteps_init(model_params): def _copy_zero(model_tree): return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) - state = dict( + return dict( optimizer=MultiStepsState( mini_step = jnp.zeros([], dtype=jnp.int32), 
gradient_step = jnp.zeros([], dtype=jnp.int32), inner_opt_state = base.init(model_params), acc_grads = _copy_zero(model_params)) ) - print("Inside multisteps_init state type is", type(state)) - - return state def multisteps_partition(optimizer_model_param_specs): return dict( @@ -792,90 +777,6 @@ def multisteps_update(updates, state, params): return PartitionedGradientTransformation( init=multisteps_init, update=multisteps_update, partition=multisteps_partition ) -# def multisteps_optimizer( -# learning_rate: float, -# *, -# b1: float, -# b2: float, -# eps: float, -# update_schedule: schedule.Schedule, -# k_steps: int, -# weight_decay: float = 0, -# weight_decay_per_param_scale: Optional[Callable[[NestedOptParam], Any]] = None, -# mu_dtype: Optional[jnp.dtype] = None, -# adam_update_transformation: Optional[ConfigOr[PartitionedGradientTransformation]] = None, -# ) -> PartitionedGradientTransformation: -# breakpoint() - -# chained_transformations = adamw_decoupled_optimizer(learning_rate=learning_rate, b1=b1, b2=b2,eps=eps,update_schedule=update_schedule, weight_decay=weight_decay, weight_decay_per_param_scale=weight_decay_per_param_scale,adam_update_transformation=adam_update_transformation,mu_dtype=mu_dtype) - -# class MultiStepsState(NamedTuple): -# mini_step: int -# gradient_step: int -# inner_opt_state: Any -# acc_grads: Any - -# def multisteps_init(model_params): -# def _copy_zero(model_tree): -# return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) -# state = dict( -# optimizer=MultiStepsState( -# mini_step = jnp.zeros([], dtype=jnp.int32), -# gradient_step = jnp.zeros([], dtype=jnp.int32), -# inner_opt_state = chained_transformations.init(model_params), -# acc_grads = _copy_zero(model_params)) -# ) -# print("Inside multisteps_init state type is", type(state)) - -# return state - -# def multisteps_partition(optimizer_model_param_specs): -# return dict( -# optimizer=MultiStepsState( -# mini_step = None, -# gradient_step = None, -# inner_opt_state = chained_transformations.partition(optimizer_model_param_specs), -# acc_grads = optimizer_model_param_specs) # TODO: check if this works was model_param_specs -# ) - -# # Copied from Optax Multistep -# def multisteps_update(updates, state, params): -# _acc_update = lambda x, y: x + y -# # Note: we do not enclose variables to allow JAX to re-use memory buffers. 
-# acc_grads = jax.tree_util.tree_map( -# lambda upd, acc: _acc_update(upd, acc), -# updates, -# state.acc_grads, -# ) - -# final_updates, new_inner_state = chained_transformations.update( -# acc_grads, state.inner_opt_state, params=params -# ) - -# emit = state.mini_step == (k_steps - 1) -# new_state = MultiStepsState( -# mini_step=optax.safe_int32_increment(state.mini_step) % k_steps, -# gradient_step=emit -# * optax.safe_int32_increment(state.gradient_step) -# + (1 - emit) * state.gradient_step, -# inner_opt_state=jax.tree_util.tree_map( -# lambda st, nst: jnp.where(emit, nst, st), -# state.inner_opt_state, -# new_inner_state, -# ), -# acc_grads=jax.tree_util.tree_map( -# lambda ga: (1 - emit) * ga, acc_grads -# ), -# ) - -# final_updates = jax.tree_util.tree_map( -# lambda ga: emit * ga, final_updates -# ) -# return final_updates, new_state - -# return PartitionedGradientTransformation( -# init=multisteps_init, update=multisteps_update, partition=multisteps_partition -# ) def adam_optimizer( learning_rate: schedule.Schedule, diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 6ce49541e..a2ab4b323 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -296,13 +296,13 @@ def learner_config( b2=b2, eps=eps, update_schedule=update_schedule, - # k_steps=4, weight_decay=weight_decay, weight_decay_per_param_scale=None, adam_update_transformation=None, mu_dtype=jnp.float32 ), - ] + ], + k_steps=4 ) return learner.Learner.default_config().set(optimizer=optimizer_cfg) From 22c1aa8b7796cf4fa77353ade1f4e1be16e73b29 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 25 Mar 2024 21:00:07 +0000 Subject: [PATCH 4/6] cleanup --- axlearn/common/learner.py | 10 +++++----- axlearn/common/optimizers.py | 4 +--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py index 39e67c096..a19e9c718 100644 --- a/axlearn/common/learner.py +++ b/axlearn/common/learner.py @@ -9,8 +9,7 @@ import dataclasses import enum -from typing import Mapping, Optional, Sequence, Tuple, Any, NamedTuple -# from typing import , Callable, Dict, List, , Optional, Sequence, Tuple, Union +from typing import Mapping, Optional, Sequence, Tuple import jax import optax @@ -164,7 +163,7 @@ def create_state_partition_specs( self, model_param_specs: NestedParameterSpec ) -> NestedPartitionSpec: optimizer_model_param_specs = self._get_optimizer_model_params(model_param_specs) - partition_state = self.optimizer.partition(optimizer_model_param_specs) + partition_state = dict(optimizer=self.optimizer.partition(optimizer_model_param_specs)) if self.config.ema.decay is not None: partition_state["ema"] = self.ema.partition(model_param_specs) return partition_state @@ -180,8 +179,9 @@ def _get_optimizer_model_params(self, model_params: NestedOptParam) -> NestedOpt def init(self, model_params: NestedOptParam) -> NestedTensor: update_types = self._update_types(model_params) register_per_param_settings(update_types, description="learner_update_type") - grad_model_params = self._get_optimizer_model_params(model_params) - state = self.optimizer.init(grad_model_params) + state = dict( + optimizer=self.optimizer.init(self._get_optimizer_model_params(model_params)), + ) if self.config.ema.decay is not None: state["ema"] = self.ema.init(model_params) return state diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 4590c4ccf..dc151e257 100644 --- a/axlearn/common/optimizers.py +++ 
b/axlearn/common/optimizers.py @@ -685,8 +685,6 @@ def adamw_decoupled_optimizer( A PartitionedGradientTransformation representing a decoupled AdamW optimizer with parameter scaling. """ - # optax.MultiSteps(optimizer, every_k_schedule=3) - # tx = [adam_partition(optax.MultiSteps(optax.scale_by_adam(b1=b1, b2=b2, eps=eps, mu_dtype=mu_dtype), every_k_schedule=2))] tx = [adam_partition(optax.scale_by_adam(b1=b1, b2=b2, eps=eps, mu_dtype=mu_dtype))] if adam_update_transformation is not None: tx.append(maybe_instantiate(adam_update_transformation)) @@ -1593,7 +1591,7 @@ def _init(param: OptParam): ) return _AdastarState( - count=jnp.zeros([], jnp.int32), pps=jax.tree_util.tree_map(_init, params) + count=jnp.zeros((1,), dtype=jnp.int32), pps=jax.tree_util.tree_map(_init, params) ) def update_fn(grads: NestedTensor, state: _AdastarState, params: NestedOptParam): From cbda8b6a37934ab978d494b6db23ad4cc640513f Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 25 Mar 2024 21:01:49 +0000 Subject: [PATCH 5/6] more cleanuo --- axlearn/common/optimizers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index dc151e257..5af8e4d61 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -113,8 +113,7 @@ def init_fn(params: NestedOptParam) -> NestedTensor: def update_fn( updates: optax.Updates, state: optax.OptState, params: NestedOptParam ) -> Tuple[optax.Updates, optax.OptState]: - print("params in update_fn", params) - return base.update(updates, state, opt_param_values(params)) # TODO: apoorvgu params wrong here + return base.update(updates, state, opt_param_values(params)) return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) @@ -443,7 +442,7 @@ def init_fn(params): count = None else: count = jnp.zeros([], jnp.int32) - return AddDecayedWeightsState(count=count) + return AddDecayedWeightsState(count=jnp.zeros((1,), dtype=jnp.int32)) def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: NestedOptParam): if params is None: From 183c27af2fe7da93008d91534e7abd8e8bffa0e6 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 25 Mar 2024 21:06:28 +0000 Subject: [PATCH 6/6] even more cleanup --- axlearn/common/optimizers.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 5af8e4d61..85df68a07 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -719,22 +719,18 @@ class MultiStepsState(NamedTuple): def multisteps_init(model_params): def _copy_zero(model_tree): return jax.tree_map(lambda x: jnp.full_like(x, 0), model_tree) - return dict( - optimizer=MultiStepsState( + return MultiStepsState( mini_step = jnp.zeros([], dtype=jnp.int32), gradient_step = jnp.zeros([], dtype=jnp.int32), inner_opt_state = base.init(model_params), acc_grads = _copy_zero(model_params)) - ) def multisteps_partition(optimizer_model_param_specs): - return dict( - optimizer=MultiStepsState( + return MultiStepsState( mini_step = None, gradient_step = None, inner_opt_state = base.partition(optimizer_model_param_specs), acc_grads = optimizer_model_param_specs) # TODO: check if this works was model_param_specs - ) # Copied from Optax Multistep def multisteps_update(updates, state, params):
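
Reference note (not part of the patches above): the gradient-accumulation pattern that multisteps_optimizer converges on in this series can be sketched as a standalone optax wrapper. This is only an illustrative sketch, assuming jax and optax are available; it wraps a plain optax.GradientTransformation rather than axlearn's PartitionedGradientTransformation (so there is no partition function), and the names GradAccumState and grad_accumulation are invented for the example rather than taken from either library.

# Illustrative sketch of k-step gradient accumulation around a plain optax optimizer.
# Names (GradAccumState, grad_accumulation) are hypothetical; this is not the axlearn code.
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import optax


class GradAccumState(NamedTuple):
    mini_step: jnp.ndarray      # counts 0 .. k_steps - 1 within one accumulation window
    gradient_step: jnp.ndarray  # number of real optimizer updates emitted so far
    inner_opt_state: Any        # state of the wrapped optimizer
    acc_grads: Any              # running sum of gradients


def grad_accumulation(inner: optax.GradientTransformation, k_steps: int) -> optax.GradientTransformation:
    """Accumulates gradients for k_steps calls, then applies `inner` once."""

    def init_fn(params):
        return GradAccumState(
            mini_step=jnp.zeros([], jnp.int32),
            gradient_step=jnp.zeros([], jnp.int32),
            inner_opt_state=inner.init(params),
            acc_grads=jax.tree_util.tree_map(jnp.zeros_like, params),
        )

    def update_fn(updates, state, params=None):
        # Add the new gradients to the running sum.
        acc_grads = jax.tree_util.tree_map(lambda g, a: g + a, updates, state.acc_grads)

        # Run the inner optimizer on the accumulated gradients; its result is
        # only kept on the emitting mini-step.
        final_updates, new_inner_state = inner.update(acc_grads, state.inner_opt_state, params)

        emit = state.mini_step == (k_steps - 1)
        new_state = GradAccumState(
            mini_step=optax.safe_int32_increment(state.mini_step) % k_steps,
            gradient_step=state.gradient_step + emit.astype(jnp.int32),
            # Keep the old inner state on non-emitting steps.
            inner_opt_state=jax.tree_util.tree_map(
                lambda old, new: jnp.where(emit, new, old),
                state.inner_opt_state,
                new_inner_state,
            ),
            # Reset the accumulator when we emit, otherwise keep the running sum.
            acc_grads=jax.tree_util.tree_map(
                lambda a: jnp.where(emit, jnp.zeros_like(a), a), acc_grads
            ),
        )
        # Zero out the parameter updates on non-emitting mini-steps.
        final_updates = jax.tree_util.tree_map(
            lambda u: jnp.where(emit, u, jnp.zeros_like(u)), final_updates
        )
        return final_updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)


if __name__ == "__main__":
    # Usage sketch: only every 4th call produces a non-zero parameter update.
    params = {"w": jnp.ones((3,))}
    opt = grad_accumulation(optax.sgd(0.1), k_steps=4)
    opt_state = opt.init(params)
    for step in range(8):
        grads = {"w": jnp.full((3,), 0.25)}
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)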