From ae48ccdb9aaedc80bae6adcc4a483abdb7c191af Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 21 Nov 2024 11:56:05 -0500 Subject: [PATCH 01/86] Use jax.jit for sharding initial steps Apply it to the MNIST workload and the Nesterov optimizer. --- algoperf/checkpoint_utils.py | 3 - algoperf/data_utils.py | 5 +- algoperf/sharding_utils.py | 62 +++++++++++++++++++ .../workloads/mnist/mnist_jax/workload.py | 23 ++++--- .../nesterov/jax/submission.py | 62 ++++++++++++++----- 5 files changed, 123 insertions(+), 32 deletions(-) create mode 100644 algoperf/sharding_utils.py diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..26ee5ee59 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -193,10 +193,7 @@ def save_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - model_params = jax.device_get(jax_utils.unreplicate(model_params)) opt_state, _ = optimizer_state - opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) - model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( model_params, diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..bac9155b5 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -60,10 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - # Reshape (global_batch_size, ...) to - # (local_device_count, per_device_batch_size, ...). - # Assumes that `global_batch_size % local_device_count == 0`. - return x.reshape((local_device_count, -1, *x.shape[1:])) + return x return jax.tree.map(_prepare, batch) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py new file mode 100644 index 000000000..4950f243a --- /dev/null +++ b/algoperf/sharding_utils.py @@ -0,0 +1,62 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec + + +def get_mesh() -> jax.sharding.Mesh: + """Creates a mesh from all available GPUs. 
Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + +def get_replicated_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + +def get_naive_sharding_spec(mesh=None): + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) + + +def get_naive_sharding(x, mesh=None): + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) + + +def shard_params(params, mesh=None): + """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, get_naive_sharding(x)), params + ) + + +def get_sharding_tree(params, mesh=None): + """Returns a sharding tree for a parameter tree.""" + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + + +def get_empty_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + + +def disp_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print( + f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n" + ) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..11429bae6 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,9 +10,9 @@ import jax.numpy as jnp import optax -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.mnist.workload import BaseMnistWorkload +from algorithmic_efficiency import param_utils, sharding_utils +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): @@ -46,7 +46,7 @@ def init_model_fn( train=True)['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' @@ -101,10 +101,13 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + static_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -125,11 +128,11 @@ def _eval_model( (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} - metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( self, 
num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + total_metrics = {'accuracy': total_metrics['accuracy'].item() / num_examples, 'loss': total_metrics['loss'].item() / num_examples} + return total_metrics diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..81f23b171 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -6,10 +6,11 @@ from flax import jax_utils import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp import optax -from algoperf import spec +from algorithmic_efficiency import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +38,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,13 +88,13 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, +# @functools.partial( +# jax.pmap, +# axis_name='batch', +# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), +# static_broadcasted_argnums=(0, 1), +# donate_argnums=(2, 3, 4)) +def train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -124,9 +125,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # # Get correct global mean loss and grad. 
loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -173,7 +172,40 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, + + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # per_device_rngs + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings + ) + outputs = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -188,8 +220,8 @@ def update_params( if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state From eb5cac7ba5854e1b88476d46abe6169860211144 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 21 Nov 2024 12:14:09 -0500 Subject: [PATCH 02/86] Use jax.jit for adamw --- .../paper_baselines/adamw/jax/submission.py | 97 ++++++++++++------- 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..d017da551 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -6,10 +6,11 @@ from flax import jax_utils import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp import optax -from algoperf import spec +from algoperf import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -50,24 +51,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -90,9 +85,8 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + + # Compute local loss and gradients loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -105,7 +99,7 @@ def _loss_fn(params): grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm @@ -139,23 +133,58 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Set up mesh and sharding + mesh = sharding_utils.get_mesh() + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Define input and output shardings + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings + ) + + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state From 82977da560637654c93dd7e511be85e2006dcfaa Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:53:50 -0500 Subject: [PATCH 03/86] Pass yapf checks --- algoperf/sharding_utils.py | 71 +++++++++---------- .../workloads/mnist/mnist_jax/workload.py | 16 +++-- algoperf/workloads/mnist/workload.py | 7 +- .../paper_baselines/adamw/jax/submission.py | 19 +++-- .../nesterov/jax/submission.py | 59 +++++++-------- 5 files changed, 86 insertions(+), 86 deletions(-) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 4950f243a..62a441bc9 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -5,58 +5,57 @@ def get_mesh() -> jax.sharding.Mesh: - """Creates a mesh from all available GPUs. Here, we simply create a one-dimensional mesh.""" - return jax.sharding.Mesh(jax.devices(), ("batch",)) + """Creates a mesh from all available GPUs. 
Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + def get_replicated_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + def get_naive_sharding_spec(mesh=None): - """Returns a sharding spec that shards data along the first axis.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec("batch")) + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) def get_naive_sharding(x, mesh=None): - """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" - if mesh is None: - mesh = get_mesh() - grid_size = mesh.shape["batch"] - if x.shape[0] % grid_size == 0: - return NamedSharding(mesh, PartitionSpec("batch")) - else: - return NamedSharding(mesh, PartitionSpec()) + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) def shard_params(params, mesh=None): - """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" - if mesh is None: - mesh = get_mesh() - return jax.tree_util.tree_map( - lambda x: jax.device_put(x, get_naive_sharding(x)), params - ) + """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, get_naive_sharding(x)), params) def get_sharding_tree(params, mesh=None): - """Returns a sharding tree for a parameter tree.""" - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + """Returns a sharding tree for a parameter tree.""" + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) def get_empty_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) def disp_shard_info(x: jax.Array): - """Displays shard info of a jax array.""" - for shard in x.addressable_shards: - print( - f"shard.device: {shard.device}, index: {shard.index}, replica_id:" - f" {shard.replica_id}.\n" - ) + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 11429bae6..b0a52d77f 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -102,11 +102,12 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - sharding_utils.get_replicated_sharding(), # model_state - 
sharding_utils.get_naive_sharding_spec(), # rng - ), + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), static_argnums=(0,)) def _eval_model( self, @@ -134,5 +135,8 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - total_metrics = {'accuracy': total_metrics['accuracy'].item() / num_examples, 'loss': total_metrics['loss'].item() / num_examples} + total_metrics = { + 'accuracy': total_metrics['accuracy'].item() / num_examples, + 'loss': total_metrics['loss'].item() / num_examples + } return total_metrics diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index f53aadd0b..c92dd141b 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -46,8 +46,7 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + 'targets': x['label'],}) is_train = split == 'train' if cache: @@ -214,8 +213,6 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index d017da551..73f41adbe 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -165,18 +165,17 @@ def update_params( static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) outputs = jitted_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
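The pattern above repeats across the submissions: build a one-dimensional device mesh, describe each argument as either replicated or batch-sharded with NamedSharding, and jit the train step with in_shardings/out_shardings plus static and donated arguments, in place of pmap with in_axes and static_broadcasted_argnums. A minimal self-contained sketch of that pattern, assuming a toy linear model and SGD step (the shapes, loss, and update rule below are placeholders, not the workload code from these patches):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("batch",))
replicated = NamedSharding(mesh, P())      # same copy on every device
sharded = NamedSharding(mesh, P("batch"))  # split along the leading axis

def train_step(params, batch):
  # The loss is computed on the global (sharded) batch; under jit the
  # compiler inserts the cross-device reductions that lax.psum used to
  # perform explicitly under pmap.
  def loss_fn(p):
    preds = batch["inputs"] @ p
    return jnp.mean((preds - batch["targets"])**2)
  loss, grad = jax.value_and_grad(loss_fn)(params)
  return params - 0.1 * grad, loss

jitted_train_step = jax.jit(
    train_step,
    in_shardings=(replicated, sharded),
    out_shardings=(replicated, replicated),
    donate_argnums=(0,))

n = 8 * jax.device_count()  # keep the batch divisible across the mesh
params = jax.device_put(jnp.zeros((4, 1)), replicated)
batch = {
    "inputs": jax.device_put(jnp.ones((n, 4)), sharded),
    "targets": jax.device_put(jnp.ones((n, 1)), sharded),
}
params, loss = jitted_train_step(params, batch)

Because the rng is replicated rather than split per device, a single key is passed straight through (as the later nesterov/adamw commits do), and the gradient mean is taken over n_valid_examples of the global batch instead of being pmean-ed across replicas.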
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 81f23b171..0c648a156 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -95,14 +95,14 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): # static_broadcasted_argnums=(0, 1), # donate_argnums=(2, 3, 4)) def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -182,20 +182,20 @@ def update_params( arg_shardings = ( # workload is static # opt_update_fn is static - replicated, # model_state - replicated, # optimizer_state - replicated, # current_param_container - sharded, # batch - sharded, # per_device_rngs + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # per_device_rngs replicated, # grad_clip - replicated # label_smoothing + replicated # label_smoothing ) out_shardings = ( - replicated, # new_optimizer_state - replicated, # updated_params - replicated, # new_model_state - replicated, # loss - replicated # grad_norm + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm ) # Jit with shardings jitted_train_step = jax.jit( @@ -203,17 +203,16 @@ def update_params( static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) outputs = jitted_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. @@ -271,6 +270,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 99545d4082569d8f98c9314836f8f349709405e4 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:58:03 -0500 Subject: [PATCH 04/86] CIFAR workload sharding --- .../cifar/cifar_jax/input_pipeline.py | 4 +- .../workloads/cifar/cifar_jax/workload.py | 153 ++++++++++++------ 2 files changed, 103 insertions(+), 54 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..18cb9ac5b 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,7 +8,6 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds @@ -171,5 +170,6 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + # FIXME(rka97): Figure out how to do prefetching+sharding. 
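A rough sketch of one way to approach the FIXME above, assuming the iterator yields numpy pytrees: keep a small buffer of batches that have already been placed on the mesh with jax.device_put, so host-to-device transfer overlaps with compute. The helper name prefetch_to_sharding and the buffer size are illustrative only, not part of this series.

import collections
import itertools

import jax

def prefetch_to_sharding(iterator, size, sharding):
  # Stay `size` batches ahead of the consumer; device_put is asynchronous,
  # so the transfers run while the previous step is still executing.
  queue = collections.deque()

  def enqueue(n):
    for batch in itertools.islice(iterator, n):
      queue.append(
          jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), batch))

  enqueue(size)
  while queue:
    yield queue.popleft()
    enqueue(1)

With a helper like this, the commented-out prefetch_to_device call below could become prefetch_to_sharding(it, 2, sharding_utils.get_naive_sharding_spec()).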
+ # it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..7dd883f1e 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax @@ -12,7 +11,7 @@ import optax import tensorflow_datasets as tfds -from algoperf import param_utils +from algoperf import param_utils, sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -28,15 +27,16 @@ def _build_cifar_dataset( data_dir: str, batch_size: int, cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: - ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) - train = split == 'train' + data_dir = data_dir + "/cifar10" + ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) + train = split == "train" assert self.num_train_examples + self.num_validation_examples == 50000 - if split in ['train', 'eval_train']: - split = f'train[:{self.num_train_examples}]' - elif split == 'validation': - split = f'train[{self.num_train_examples}:]' + if split in ["train", "eval_train"]: + split = f"train[:{self.num_train_examples}]" + elif split == "validation": + split = f"train[{self.num_train_examples}:]" ds = create_input_iter( split, ds_builder, @@ -48,7 +48,8 @@ def _build_cifar_dataset( self.padding_size, train=train, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( @@ -59,7 +60,8 @@ def _build_input_queue( global_batch_size: int, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches return self._build_cifar_dataset(data_rng, split, @@ -74,34 +76,45 @@ def sync_batch_stats( # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics # and we average them. 
+<<<<<<< variant A avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) +>>>>>>> variant B + avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x") + new_model_state = model_state.copy( + {"batch_stats": avg_fn(model_state["batch_stats"])}) +======= end return new_model_state def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate - model_cls = getattr(models, 'ResNet18') + model_cls = getattr(models, "ResNet18") model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, + variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) +<<<<<<< variant A model_state, params = pop(variables, 'params') +>>>>>>> variant B + model_state, params = variables.pop("params") +======= end self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + # model_state = jax_utils.replicate(model_state) + # params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: - return param_key == 'Dense_0' + return param_key == "Dense_0" def model_fn( self, @@ -111,26 +124,38 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, +<<<<<<< variant A use_running_average_bn: Optional[bool] = None +>>>>>>> variant B +======= end ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng - variables = {'params': params, **model_state} + variables = {"params": params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, +<<<<<<< variant A mutable=['batch_stats'], use_running_average_bn=use_running_average_bn) +>>>>>>> variant B + mutable=["batch_stats"], + ) +======= end return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, mutable=False, +<<<<<<< variant A use_running_average_bn=use_running_average_bn) +>>>>>>> variant B + ) +======= end return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in @@ -140,13 +165,15 @@ def loss_fn( label_batch: spec.Tensor, # Dense or one-hot labels. logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ + Return {'summed': scalar summed loss, + 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -159,51 +186,73 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + "summed": summed_loss, + "n_valid_examples": n_valid_examples, + "per_example": per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)['summed'] + summed_loss = self.loss_fn(labels, logits, weights)["summed"] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, - } - metrics = lax.psum(metrics, axis_name='batch') - return metrics + return jnp.array(summed_loss), jnp.array(accuracy) - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + ) + def _per_device_eval_model( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + weights = batch.get("weights") + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch["targets"], weights) + + losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) + metrics = { + "loss": + jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, + "accuracy": + (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies + ), + } + return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" +<<<<<<< variant A return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) +>>>>>>> variant B + return jax.tree_map(lambda x: x / num_examples, total_metrics) +======= end From 018711a77c43bbf2477d9bfbd8ec1971572db222 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 7 Jan 2025 21:18:44 +0000 Subject: [PATCH 05/86] librispeech_conformer now running Still need to test out (a) output losses, 
(b) speed, and (c) look into other librispeech. --- .../librispeech_jax/workload.py | 101 +++++++++++++----- .../nesterov/jax/submission.py | 5 +- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..b2f5a2903 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -7,6 +7,9 @@ import flax.linen as nn import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P + +from algoperf import sharding_utils import jax.numpy as jnp import numpy as np import optax @@ -21,7 +24,6 @@ LibriSpeechDataset from algoperf.workloads.librispeech_conformer.librispeech_jax import models - class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): def __init__(self, @@ -93,8 +95,16 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + + # Add sharding + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + model_state) + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -180,6 +190,7 @@ def _build_input_queue( 'targets': (targets.numpy(), target_paddings.numpy()), } + # Use data_utils.shard_and_maybe_pad_np to handle sharding padded_batch = data_utils.shard_and_maybe_pad_np( numpy_batch, padding_value=1.0) yield padded_batch @@ -305,11 +316,16 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) - def eval_step_pmapped( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + out_shardings=sharding_utils.get_naive_sharding_spec(), + static_argnums=(0,)) + def _eval_step( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -327,13 +343,39 @@ def eval_step_pmapped( loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) targets, target_paddings = batch['targets'] - return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') + # Convert metrics bundle to dictionary + metrics_dict = { + 'loss_per_example': loss['per_example'], + 'decoded': decoded, + 'decoded_paddings': decoded_paddings, + 'targets': targets, + 'target_paddings': target_paddings, + 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + } + return metrics_dict + + def eval_step( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): + """Evaluates the model and returns a metrics bundle.""" + metrics_dict = self._eval_step(params, batch, model_state, rng) + + # Convert dictionary back to metrics bundle + metrics = 
self.metrics_bundle.single_from_model_output( + loss_dict={ + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + }, + decoded=metrics_dict['decoded'], + decoded_paddings=metrics_dict['decoded_paddings'], + targets=metrics_dict['targets'], + target_paddings=metrics_dict['target_paddings']) + + return metrics def _eval_model_on_split(self, split: str, @@ -358,10 +400,10 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step(params, + eval_batch, + model_state, + rng) if metrics_report is None: metrics_report = computed_metrics @@ -373,15 +415,22 @@ def _eval_model_on_split(self, return computed_metrics + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # model_state + ), + out_shardings=sharding_utils.get_replicated_sharding(), + static_argnums=(0,) + ) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state + """Sync batch statistics across replicas.""" + # Replace pmean with direct mean across devices + new_batch_stats = jax.tree_map( + lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) + return model_state.copy({'batch_stats': new_batch_stats}) class LibriSpeechConformerAttentionTemperatureWorkload( diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0c648a156..544ba3c13 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -163,7 +163,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -186,7 +185,7 @@ def update_params( replicated, # optimizer_state replicated, # current_param_container sharded, # batch - sharded, # per_device_rngs + replicated, # rngs replicated, # grad_clip replicated # label_smoothing ) @@ -210,7 +209,7 @@ def update_params( optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs From fbeb5f1b1fa184c9fcb881058564f856bc2e89c7 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:33:39 +0000 Subject: [PATCH 06/86] fix formatting --- algoperf/sharding_utils.py | 4 +- .../librispeech_jax/models.py | 10 ++-- .../librispeech_jax/spectrum_augmenter.py | 4 +- .../librispeech_jax/workload.py | 60 ++++++++++--------- .../librispeech_jax/models.py | 10 ++-- .../workloads/mnist/mnist_jax/workload.py | 3 +- .../paper_baselines/adamw/jax/submission.py | 17 +++--- .../nesterov/jax/submission.py | 9 +-- submission_runner.py | 2 + 9 files changed, 61 insertions(+), 58 deletions(-) diff --git 
a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 62a441bc9..93a4dd53f 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -1,7 +1,9 @@ """Utilities for dealing with sharding in JAX.""" import jax -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec def get_mesh() -> jax.sharding.Mesh: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 593d463c3..2bb527a36 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,12 +442,10 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..c16740629 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights + < multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index b2f5a2903..d68974cd8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -7,15 +7,14 @@ import flax.linen as nn import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P - -from algoperf import sharding_utils import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P import numpy as np import optax import torch from algoperf import data_utils +from algoperf import sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics @@ -24,6 +23,7 @@ LibriSpeechDataset from algoperf.workloads.librispeech_conformer.librispeech_jax import models + class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): def __init__(self, @@ -99,10 +99,12 @@ def init_model_fn( # Add sharding mesh = sharding_utils.get_mesh() params = jax.tree_map( - lambda x: 
jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), params) model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), model_state) return params, model_state @@ -345,30 +347,35 @@ def _eval_step( targets, target_paddings = batch['targets'] # Convert metrics bundle to dictionary metrics_dict = { - 'loss_per_example': loss['per_example'], - 'decoded': decoded, - 'decoded_paddings': decoded_paddings, - 'targets': targets, - 'target_paddings': target_paddings, - 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + 'loss_per_example': + loss['per_example'], + 'decoded': + decoded, + 'decoded_paddings': + decoded_paddings, + 'targets': + targets, + 'target_paddings': + target_paddings, + 'n_valid_examples': + jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] } return metrics_dict - def eval_step( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState): + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): """Evaluates the model and returns a metrics bundle.""" metrics_dict = self._eval_step(params, batch, model_state, rng) # Convert dictionary back to metrics bundle metrics = self.metrics_bundle.single_from_model_output( loss_dict={ - 'summed': metrics_dict['loss_per_example'].sum(), - 'per_example': metrics_dict['loss_per_example'], - 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() }, decoded=metrics_dict['decoded'], decoded_paddings=metrics_dict['decoded_paddings'], @@ -400,10 +407,7 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step(params, - eval_batch, - model_state, - rng) + computed_metrics = self.eval_step(params, eval_batch, model_state, rng) if metrics_report is None: metrics_report = computed_metrics @@ -421,15 +425,13 @@ def _eval_model_on_split(self, sharding_utils.get_replicated_sharding(), # model_state ), out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,) - ) + static_argnums=(0,)) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: """Sync batch statistics across replicas.""" # Replace pmean with direct mean across devices - new_batch_stats = jax.tree_map( - lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) + new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) return model_state.copy({'batch_stats': new_batch_stats}) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..c6e7275be 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, 
jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,12 +273,10 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index b0a52d77f..b57dd72dd 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,8 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils, sharding_utils +from algorithmic_efficiency import param_utils +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 73f41adbe..156c7ab20 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -6,8 +6,9 @@ from flax import jax_utils import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax from algoperf import spec, sharding_utils @@ -124,7 +125,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -146,17 +146,17 @@ def update_params( replicated, # model_state replicated, # optimizer_state replicated, # current_param_container - sharded, # batch - sharded, # rng + sharded, # batch + replicated, # rng replicated, # grad_clip - replicated # label_smoothing + replicated # label_smoothing ) out_shardings = ( replicated, # new_optimizer_state replicated, # updated_params replicated, # new_model_state replicated, # loss - replicated # grad_norm + replicated # grad_norm ) # Jit with shardings @@ -167,16 +167,15 @@ def update_params( in_shardings=arg_shardings, out_shardings=out_shardings) - outputs = jitted_train_step(workload, + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 544ba3c13..5b332030b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -6,11 +6,13 @@ from flax import jax_utils import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax -from algorithmic_efficiency import spec, sharding_utils +from algorithmic_efficiency import sharding_utils +from algorithmic_efficiency import spec _GRAD_CLIP_EPS = 1e-6 @@ -203,7 +205,7 @@ def update_params( donate_argnums=(2, 3, 4), in_shardings=arg_shardings, out_shardings=out_shardings) - outputs = jitted_train_step(workload, + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -212,7 +214,6 @@ def update_params( rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/submission_runner.py b/submission_runner.py index a2521e77b..ce61bb581 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,8 @@ from absl import logging import jax import tensorflow as tf +jax.config.update('jax_default_prng_impl', 'threefry2x32') +jax.config.update('jax_threefry_partitionable', True) import torch import torch.distributed as dist From 6e4e7b0023c025e68bcbb05ef1e37de97886a8f8 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:34:01 +0000 Subject: [PATCH 07/86] shard default --- algoperf/data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index bac9155b5..6fede4430 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algoperf import spec +from algoperf import spec, sharding_utils def shard_and_maybe_pad_np( @@ -60,7 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return x + return jax.device_put(x, sharding_utils.get_naive_sharding_spec()) # x return jax.tree.map(_prepare, batch) From 4a2c02d3155e28622b24c014774741edd50b4736 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:34:08 +0000 Subject: [PATCH 08/86] start imagenet --- algoperf/data_utils.py | 4 +- .../imagenet_jax/input_pipeline.py | 2 +- .../imagenet_resnet/imagenet_jax/workload.py | 50 +++++++++++++------ 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 6fede4430..44eca8c26 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -50,7 +50,7 @@ def shard_and_maybe_pad_np( weights = batch.get('weights') # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights - + naive_sharding_spec = sharding_utils.get_naive_sharding_spec() def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. 
if not isinstance(x, np.ndarray): @@ -60,7 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return jax.device_put(x, sharding_utils.get_naive_sharding_spec()) # x + return jax.device_put(x, naive_sharding_spec) return jax.tree.map(_prepare, batch) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..b1be5ac1f 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -399,6 +399,6 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. - it = jax_utils.prefetch_to_device(it, 2) + # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..b60fd2753 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -19,6 +19,7 @@ import tensorflow_datasets as tfds from algoperf import param_utils +from algoperf import sharding_utils from algoperf import random_utils as prng from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 @@ -71,16 +72,20 @@ def _build_dataset( use_randaug=use_randaug) return ds + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # model_state + ), + out_shardings=sharding_utils.get_replicated_sharding(), + static_argnums=(0,)) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() # Create a shallow copy - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state + """Sync batch statistics across replicas.""" + new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) + return model_state.copy({'batch_stats': new_batch_stats}) + def init_model_fn( self, @@ -113,18 +118,30 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), + model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_naive_sharding_spec()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -218,7 +235,7 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') + # metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -252,11 +269,12 @@ def _eval_model_on_split(self, eval_rng = prng.fold_in(eval_rng, bi) step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) - # We already average these metrics across devices inside _compute_metrics. 
synced_metrics = self._eval_model(params, batch, model_state, step_eval_rngs) + # Sum up the synced metrics + synced_metrics = jax.tree_map(lambda x: jnp.sum(x, axis=0), synced_metrics) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 From 47beba15f9b6a2d6262ce228e5e4b8b82ae53c70 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:34:13 +0000 Subject: [PATCH 09/86] remove bn sync in imagenet (jit handles it automatically) --- .../workloads/imagenet_resnet/imagenet_jax/workload.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index b60fd2753..d7285ab76 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -141,7 +141,7 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: sharding_utils.get_replicated_sharding(), # rng ), static_argnums=(0,), - out_shardings=sharding_utils.get_naive_sharding_spec()) + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -248,9 +248,6 @@ def _eval_model_on_split(self, data_dir: str, global_step: int = 0) -> Dict[str, float]: del global_step - if model_state is not None: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) data_rng, eval_rng = prng.split(rng, 2) # We already repeat the dataset indefinitely in tf.data. @@ -273,14 +270,12 @@ def _eval_model_on_split(self, batch, model_state, step_eval_rngs) - # Sum up the synced metrics - synced_metrics = jax.tree_map(lambda x: jnp.sum(x, axis=0), synced_metrics) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics From 3a18f19c87e0869eceb12894fa4bb7a636a276af Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 6 Feb 2025 18:08:16 +0000 Subject: [PATCH 10/86] ImageNet-ViT also works --- algoperf/sharding_utils.py | 5 +++++ .../imagenet_resnet/imagenet_jax/workload.py | 19 +------------------ .../imagenet_vit/imagenet_jax/workload.py | 7 ++++--- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 93a4dd53f..f158b5f9b 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -17,6 +17,11 @@ def get_replicated_sharding(mesh=None): mesh = get_mesh() return NamedSharding(mesh, PartitionSpec()) +def shard_replicated(x, mesh=None): + """Shards a tensor across all devices.""" + if mesh is None: + mesh = get_mesh() + return jax.tree_map(lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) def get_naive_sharding_spec(mesh=None): """Returns a sharding spec that shards data along the first axis.""" diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index d7285ab76..a2fc327a1 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -72,21 +72,6 @@ 
def _build_dataset( use_randaug=use_randaug) return ds - @functools.partial( - jax.jit, - in_shardings=( - sharding_utils.get_replicated_sharding(), # model_state - ), - out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,)) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync batch statistics across replicas.""" - new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) - return model_state.copy({'batch_stats': new_batch_stats}) - - def init_model_fn( self, rng: spec.RandomState, @@ -235,7 +220,6 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - # metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -264,12 +248,11 @@ def _eval_model_on_split(self, eval_metrics = {} for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) - step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) synced_metrics = self._eval_model(params, batch, model_state, - step_eval_rngs) + eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..22e714a37 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,5 +1,6 @@ """ImageNet workload implemented in Jax.""" +import functools from typing import Dict, Optional, Tuple from flax import jax_utils @@ -8,7 +9,7 @@ import jax import jax.numpy as jnp -from algoperf import param_utils +from algoperf import param_utils, sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload @@ -46,8 +47,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = sharding_utils.shard_replicated(params) + model_state = sharding_utils.shard_replicated(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: From bd0f5650988a037648944dca9f64cc4bd1918e42 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 20 Feb 2025 15:45:12 +0000 Subject: [PATCH 11/86] Start working on WMT. 
OOM error --- algoperf/sharding_utils.py | 12 +++- algoperf/workloads/wmt/wmt_jax/workload.py | 67 ++++++++++++++++------ 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index f158b5f9b..4c330b767 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -35,12 +35,11 @@ def get_naive_sharding(x, mesh=None): if mesh is None: mesh = get_mesh() grid_size = mesh.shape["batch"] - if x.shape[0] % grid_size == 0: + if len(x.shape) > 0 and x.shape[0] % grid_size == 0: return NamedSharding(mesh, PartitionSpec("batch")) else: return NamedSharding(mesh, PartitionSpec()) - def shard_params(params, mesh=None): """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" if mesh is None: @@ -48,6 +47,15 @@ def shard_params(params, mesh=None): return jax.tree_util.tree_map( lambda x: jax.device_put(x, get_naive_sharding(x)), params) +def shard_naive(x, mesh=None): + return shard_params(x, mesh) + + + +def get_naive_sharding_tree(input_tree, mesh=None): + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), input_tree) def get_sharding_tree(params, mesh=None): """Returns a sharding tree for a parameter tree.""" diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..138da706d 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -13,8 +13,9 @@ import numpy as np import optax -from algoperf import param_utils +from algoperf import param_utils, sharding_utils from algoperf import spec +from algoperf import sharding_utils from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode from algoperf.workloads.wmt.wmt_jax import models @@ -69,10 +70,16 @@ def compute_weighted_cross_entropy( } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + ), + static_argnums=(0,), # self + ) + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] @@ -89,15 +96,23 @@ def eval_step_pmapped( 'accuracy': acc_sum, 'denominator': weight_sum, } +<<<<<<< variant A def eval_step(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) +>>>>>>> variant B +======= end @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_naive_sharding_spec(), # inputs + ), + static_argnums=(0,2,) + ) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: @@ -108,11 +123,10 @@ def initialize_cache(self, jax.random.PRNGKey(0), jnp.ones(inputs.shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) + print(initial_variables['cache']) + print(jax.tree_map(lambda x: x.shape, initial_variables['cache'])) return initial_variables['cache'] - # eos_id, max_decode_len are constant. 
- @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) def predict_step(self, inputs: spec.Tensor, params: spec.ParameterContainer, @@ -180,20 +194,34 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] + jitted_predit_step = None for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], + if jitted_predit_step is None: + jitted_predict_step = jax.jit( + self.predict_step, + in_shardings=( + sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_tree(cache), # cache + ), + static_argnums=(3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) + predicted = jitted_predict_step(pred_batch['inputs'], params, cache, decode.EOS_ID, max_predict_length) - predicted = _to_host(predicted) - targets = _to_host(pred_batch['targets']) + # predicted = _to_host(predicted) + # targets = _to_host(pred_batch['targets']) + targets = pred_batch['targets'] # Find actual batch size, ignoring the potential padding. weights = pred_batch.get('weights') if weights is not None: - weights = _to_host(weights) + # weights = _to_host(weights) actual_batch_size = int(weights.sum(0)[0].item()) else: actual_batch_size = len(predicted) @@ -213,7 +241,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - init_fake_batch_size = 2 + init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -235,15 +263,20 @@ def init_model_fn( eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) params_rng, dropout_rng = jax.random.split(rng) + inputs = jnp.ones(input_shape, jnp.float32) + targets = jnp.ones(target_shape, jnp.float32) + sharded_inputs = sharding_utils.shard_naive(inputs) + sharded_targets = sharding_utils.shard_naive(targets) + initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + sharded_inputs, sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + params = sharding_utils.shard_replicated(initial_params) + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' From 3044efb3a0050b592beaf068848834e624460459 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 20 Feb 2025 17:37:00 +0000 Subject: [PATCH 12/86] post-rebase, still on wmt --- algoperf/workloads/mnist/mnist_jax/workload.py | 8 ++++---- algoperf/workloads/wmt/wmt_jax/workload.py | 11 +---------- .../paper_baselines/nesterov/jax/submission.py | 4 ++-- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index b57dd72dd..38f6e332f 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,10 +10,10 @@ import jax.numpy as jnp import optax -from 
algorithmic_efficiency import param_utils -from algorithmic_efficiency import sharding_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload +from algoperf import param_utils +from algoperf import sharding_utils +from algoperf import spec +from algoperf.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 138da706d..cbcd70f31 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -96,15 +96,6 @@ def eval_step(self, 'accuracy': acc_sum, 'denominator': weight_sum, } -<<<<<<< variant A - - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: - replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) ->>>>>>> variant B -======= end @functools.partial( jax.jit, @@ -124,7 +115,7 @@ def initialize_cache(self, jnp.ones(inputs.shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) print(initial_variables['cache']) - print(jax.tree_map(lambda x: x.shape, initial_variables['cache'])) + print(jax.tree.map(lambda x: x.shape, initial_variables['cache'])) return initial_variables['cache'] def predict_step(self, diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 5b332030b..49e46109b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -11,8 +11,8 @@ from jax.sharding import PartitionSpec as P import optax -from algorithmic_efficiency import sharding_utils -from algorithmic_efficiency import spec +from algoperf import sharding_utils +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 From e301c492d271aef132120bd9dbdc9ad606dbae8c Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 20 Feb 2025 18:33:32 +0000 Subject: [PATCH 13/86] cache sharding fix --- algoperf/workloads/wmt/wmt_jax/workload.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cbcd70f31..b162b9e4c 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -110,12 +110,12 @@ def initialize_cache(self, """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] + dummy_inputs = sharding_utils.shard_naive(jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = sharding_utils.shard_naive(jnp.ones(target_shape, jnp.float32)) initial_variables = models.Transformer(config).init( jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) - print(initial_variables['cache']) - print(jax.tree.map(lambda x: x.shape, initial_variables['cache'])) + dummy_inputs, + dummy_targets) return initial_variables['cache'] def predict_step(self, @@ -185,11 +185,12 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] - jitted_predit_step = None + jitted_predict_step = None for _ in 
range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - if jitted_predit_step is None: + cache = sharding_utils.shard_naive(cache) + if jitted_predict_step is None: jitted_predict_step = jax.jit( self.predict_step, in_shardings=( From 4fcf98452d5ec77be4e076515d62b27e5c6d8386 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 21 Feb 2025 03:42:39 +0000 Subject: [PATCH 14/86] target_setting_algorithms sharding, compilation caching compilation caching really speeds things up when doing repeated runs. --- .../target_setting_algorithms/jax_adamw.py | 2 +- .../target_setting_algorithms/jax_momentum.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 2 +- .../target_setting_algorithms/jax_nesterov.py | 2 +- .../jax_submission_base.py | 71 +++++++++++++------ submission_runner.py | 6 ++ 6 files changed, 58 insertions(+), 27 deletions(-) diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..b99d1fd94 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -41,4 +41,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..9da67e8f9 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=False) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -168,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..533e23f2c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..88718ab67 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,26 +7,20 @@ import jax.numpy as jnp import 
optax -from algoperf import spec +from algoperf import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -49,9 +43,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Compute mean loss and grad loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -98,9 +90,42 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = sharding_utils.get_replicated_sharding(mesh) # No partitioning + sharded = sharding_utils.get_naive_sharding_spec(mesh) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, grad_clip, + current_param_container, batch, rng, grad_clip, label_smoothing) # Log loss, grad_norm. 
@@ -108,8 +133,8 @@ def update_params( workload.metrics_logger is not None): workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/submission_runner.py b/submission_runner.py index ce61bb581..43d707519 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,8 +31,14 @@ from absl import logging import jax import tensorflow as tf +# New PRNG implementation for correct sharding jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) +# JAX compilation caching +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +# jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") import torch import torch.distributed as dist From d147e3914f2495136a6c7b5c020328831afc2589 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 21 Feb 2025 04:16:22 +0000 Subject: [PATCH 15/86] Update tests to correct batch size --- .github/workflows/CI.yml | 20 ++++++++++---------- tests/reference_algorithm_tests.py | 4 +++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64ef0302e..ccd99e68d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json wmt_pytorch: runs-on: ubuntu-latest steps: @@ -54,7 +54,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json imagenet_jax: runs-on: ubuntu-latest steps: @@ -71,8 +71,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json imagenet_pytorch: runs-on: ubuntu-latest steps: @@ -89,8 +89,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json # uncomment when https://github.com/mlcommons/algorithmic-efficiency/issues/339 is resolved. criteo_jax: runs-on: ubuntu-latest @@ -142,8 +142,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json speech_pytorch: runs-on: ubuntu-latest steps: @@ -160,8 +160,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ogbg: runs-on: ubuntu-latest steps: diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..db8928309 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -225,7 +225,7 @@ def _build_input_queue(self, *args, **kwargs): del kwargs np.random.seed(42) - if framework == 'jax' or USE_PYTORCH_DDP: + if USE_PYTORCH_DDP: batch_shape = (n_gpus, global_batch_size // n_gpus) else: batch_shape = (global_batch_size,) @@ -422,6 +422,8 @@ def _test_submission(workload_name, global_batch_size = FLAGS.global_batch_size if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') + elif global_batch_size < n_gpus and FLAGS.framework == 'jax': + raise ValueError('Global batch size cannot be smaller than the number of GPUs when using JAX sharding.') workload = _make_one_batch_workload(workload_class, 
workload_name, framework, From a2b61bed56f228cd152f0e2bd6de0318366d472a Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 21 Feb 2025 04:48:31 +0000 Subject: [PATCH 16/86] yapf and isort checks.. --- algoperf/checkpoint_utils.py | 1 - algoperf/data_utils.py | 4 +- algoperf/profiler.py | 4 +- algoperf/sharding_utils.py | 24 +++++---- .../workloads/cifar/cifar_jax/workload.py | 31 +---------- .../fastmri/fastmri_pytorch/workload.py | 4 +- .../imagenet_jax/custom_tf_addons.py | 46 ++++++++-------- .../imagenet_jax/randaugment.py | 8 ++- .../imagenet_resnet/imagenet_jax/workload.py | 10 ++-- .../imagenet_pytorch/workload.py | 4 +- .../imagenet_vit/imagenet_jax/workload.py | 5 +- .../librispeech_jax/workload.py | 9 ++-- .../librispeech_pytorch/workload.py | 9 ++-- .../workloads/mnist/mnist_jax/workload.py | 1 - algoperf/workloads/wmt/bleu.py | 3 +- algoperf/workloads/wmt/wmt_jax/workload.py | 54 ++++++++++--------- algoperf/workloads/wmt/wmt_pytorch/models.py | 4 +- .../external_tuning/jax_nadamw_full_budget.py | 10 ++-- .../jax_nadamw_target_setting.py | 10 ++-- .../self_tuning/jax_nadamw_full_budget.py | 10 ++-- .../self_tuning/jax_nadamw_target_setting.py | 10 ++-- .../paper_baselines/adamw/jax/submission.py | 3 +- .../paper_baselines/nadamw/jax/submission.py | 10 ++-- .../paper_baselines/sam/jax/submission.py | 8 +-- .../shampoo/jax/distributed_shampoo.py | 28 ++++------ .../target_setting_algorithms/jax_nadamw.py | 10 ++-- .../jax_submission_base.py | 10 ++-- submission_runner.py | 6 +-- tests/modeldiffs/wmt/compare.py | 6 +-- .../modeldiffs/wmt_attention_temp/compare.py | 6 +-- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +-- tests/modeldiffs/wmt_post_ln/compare.py | 6 +-- tests/reference_algorithm_tests.py | 4 +- 33 files changed, 172 insertions(+), 192 deletions(-) diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 26ee5ee59..75d8d59ea 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -11,7 +11,6 @@ from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax import numpy as np from tensorflow.io import gfile # pytype: disable=import-error import torch diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 44eca8c26..abd3f51b3 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,8 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algoperf import spec, sharding_utils +from algoperf import sharding_utils +from algoperf import spec def shard_and_maybe_pad_np( @@ -51,6 +52,7 @@ def shard_and_maybe_pad_np( # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights naive_sharding_spec = sharding_utils.get_naive_sharding_spec() + def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. 
if not isinstance(x, np.ndarray): diff --git a/algoperf/profiler.py b/algoperf/profiler.py index fa2a1bee2..d73efd964 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) + for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 4c330b767..fc6b38d4a 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -1,13 +1,13 @@ """Utilities for dealing with sharding in JAX.""" import jax -from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec def get_mesh() -> jax.sharding.Mesh: - """Creates a mesh from all available GPUs. Here, we simply create a one-dimensional mesh.""" + """Creates a mesh from all available GPUs. + Here, we simply create a one-dimensional mesh.""" return jax.sharding.Mesh(jax.devices(), ("batch",)) @@ -17,11 +17,14 @@ def get_replicated_sharding(mesh=None): mesh = get_mesh() return NamedSharding(mesh, PartitionSpec()) + def shard_replicated(x, mesh=None): """Shards a tensor across all devices.""" if mesh is None: mesh = get_mesh() - return jax.tree_map(lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) + return jax.tree.map( + lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) + def get_naive_sharding_spec(mesh=None): """Returns a sharding spec that shards data along the first axis.""" @@ -40,26 +43,29 @@ def get_naive_sharding(x, mesh=None): else: return NamedSharding(mesh, PartitionSpec()) + def shard_params(params, mesh=None): - """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + """Shards a parameter tree across all devices + with naive sharding (see get_naive_sharding).""" if mesh is None: mesh = get_mesh() - return jax.tree_util.tree_map( - lambda x: jax.device_put(x, get_naive_sharding(x)), params) + return jax.tree.map(lambda x: jax.device_put(x, get_naive_sharding(x)), + params) + def shard_naive(x, mesh=None): return shard_params(x, mesh) - def get_naive_sharding_tree(input_tree, mesh=None): if mesh is None: mesh = get_mesh() - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), input_tree) + return jax.tree.map(lambda x: get_naive_sharding(x, mesh), input_tree) + def get_sharding_tree(params, mesh=None): """Returns a sharding tree for a parameter tree.""" - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + return jax.tree.map(lambda x: get_naive_sharding(x, mesh), params) def get_empty_sharding(mesh=None): diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 7dd883f1e..f827fac87 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -11,7 +11,8 @@ import optax import tensorflow_datasets as tfds -from algoperf import param_utils, sharding_utils +from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -76,15 +77,9 @@ def sync_batch_stats( # An axis_name is passed to pmap which can 
then be used by pmean. # In this case each device has its own version of the batch statistics # and we average them. -<<<<<<< variant A avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) ->>>>>>> variant B - avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - new_model_state = model_state.copy( - {"batch_stats": avg_fn(model_state["batch_stats"])}) -======= end return new_model_state def init_model_fn( @@ -102,15 +97,9 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) -<<<<<<< variant A model_state, params = pop(variables, 'params') ->>>>>>> variant B - model_state, params = variables.pop("params") -======= end self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - # model_state = jax_utils.replicate(model_state) - # params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -124,10 +113,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, -<<<<<<< variant A use_running_average_bn: Optional[bool] = None ->>>>>>> variant B -======= end ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng @@ -137,13 +123,8 @@ def model_fn( variables, augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, -<<<<<<< variant A mutable=['batch_stats'], use_running_average_bn=use_running_average_bn) ->>>>>>> variant B - mutable=["batch_stats"], - ) -======= end return logits, new_model_state else: logits = self._model.apply( @@ -151,11 +132,7 @@ def model_fn( augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, mutable=False, -<<<<<<< variant A use_running_average_bn=use_running_average_bn) ->>>>>>> variant B - ) -======= end return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in @@ -251,8 +228,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" -<<<<<<< variant A - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) ->>>>>>> variant B return jax.tree_map(lambda x: x / num_examples, total_metrics) -======= end diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 58943de2f..216a033d4 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -250,9 +250,7 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 3d6939218..c9bf154bb 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -20,27 +20,31 @@ tf.dtypes.float64, } -Number = Union[float, - int, - 
np.float16, - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64,] - -TensorLike = Union[List[Union[Number, list]], - tuple, - Number, - np.ndarray, - tf.Tensor, - tf.SparseTensor, - tf.Variable,] +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] + +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, +] def get_ndims(image): diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index c68e2de33..2d7e873c2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -316,8 +316,7 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, + tf.equal(step, 0), lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -552,7 +551,6 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), lambda: image) return image diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index a2fc327a1..6a7cf03be 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -19,8 +19,8 @@ import tensorflow_datasets as tfds from algoperf import param_utils -from algoperf import sharding_utils from algoperf import random_utils as prng +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -249,17 +249,13 @@ def _eval_model_on_split(self, for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) batch = next(self._eval_iters[split]) - synced_metrics = self._eval_model(params, - batch, - model_state, - eval_rng) + synced_metrics = self._eval_model(params, batch, model_state, eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: x / num_examples, - eval_metrics) + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ed29271f3..bd07edee1 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -307,9 +307,7 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + 
batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 22e714a37..1a2ce6342 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,15 +1,14 @@ """ImageNet workload implemented in Jax.""" -import functools from typing import Dict, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax import jax.numpy as jnp -from algoperf import param_utils, sharding_utils +from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index d68974cd8..69e4351d3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,20 +2,17 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec as P import numpy as np import optax import torch from algoperf import data_utils -from algoperf import sharding_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics from algoperf.workloads.librispeech_conformer import workload @@ -371,7 +368,7 @@ def eval_step(self, metrics_dict = self._eval_step(params, batch, model_state, rng) # Convert dictionary back to metrics bundle - metrics = self.metrics_bundle.single_from_model_output( + metrics_bundle = self.metrics_bundle.single_from_model_output( loss_dict={ 'summed': metrics_dict['loss_per_example'].sum(), 'per_example': metrics_dict['loss_per_example'], @@ -382,7 +379,7 @@ def eval_step(self, targets=metrics_dict['targets'], target_paddings=metrics_dict['target_paddings']) - return metrics + return metrics_bundle def _eval_model_on_split(self, split: str, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 5ed37957e..1dae389d7 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -259,8 +259,9 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + fin_result.shape[1], + device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( + -1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -328,9 +329,7 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for 
metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 38f6e332f..15fc6ff89 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -6,7 +6,6 @@ from flax import jax_utils from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..5e175320a 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -283,8 +283,7 @@ def ref_stats(output, refs): closest_diff = diff closest_len = reflen elif diff == closest_diff: - if reflen < closest_len: - closest_len = reflen + closest_len = min(reflen, closest_len) ngrams_ref = extract_ngrams(ref) for ngram in ngrams_ref: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index b162b9e4c..dce08a677 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -13,9 +13,9 @@ import numpy as np import optax -from algoperf import param_utils, sharding_utils -from algoperf import spec +from algoperf import param_utils from algoperf import sharding_utils +from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode from algoperf.workloads.wmt.wmt_jax import models @@ -72,10 +72,10 @@ def compute_weighted_cross_entropy( @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch ), - static_argnums=(0,), # self + static_argnums=(0,), # self ) def eval_step(self, params: spec.ParameterContainer, @@ -100,22 +100,24 @@ def eval_step(self, @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_naive_sharding_spec(), # inputs ), - static_argnums=(0,2,) - ) + static_argnums=( + 0, + 2, + )) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - dummy_inputs = sharding_utils.shard_naive(jnp.ones(inputs.shape, jnp.float32)) - dummy_targets = sharding_utils.shard_naive(jnp.ones(target_shape, jnp.float32)) + dummy_inputs = sharding_utils.shard_naive( + jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = sharding_utils.shard_naive( + jnp.ones(target_shape, jnp.float32)) initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - dummy_inputs, - dummy_targets) + jax.random.PRNGKey(0), dummy_inputs, dummy_targets) return initial_variables['cache'] def predict_step(self, @@ -194,19 +196,20 @@ def translate_and_calculate_bleu(self, jitted_predict_step = jax.jit( self.predict_step, in_shardings=( - sharding_utils.get_naive_sharding_spec(), # inputs - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_tree(cache), # cache + sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_tree(cache), # cache ), - static_argnums=(3, # eos_id - 4, # max_decode_len, - 
5, # beam_size - )) + static_argnums=( + 3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) predicted = jitted_predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) + params, + cache, + decode.EOS_ID, + max_predict_length) # predicted = _to_host(predicted) # targets = _to_host(pred_batch['targets']) targets = pred_batch['targets'] @@ -262,7 +265,8 @@ def init_model_fn( initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - sharded_inputs, sharded_targets) + sharded_inputs, + sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..089f1bfbb 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) + >= cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index c451a18ac..ade3e9f63 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index b8ac10f33..d3a27411c 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return 
optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 78c3b5b3e..9bc014ed1 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index ffe854a0e..a6781612c 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 156c7ab20..d4dd0539b 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -11,7 +11,8 @@ from jax.sharding import PartitionSpec as P import optax -from algoperf import spec, sharding_utils +from algoperf import sharding_utils +from algoperf import spec 
_GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index c451a18ac..ade3e9f63 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index b76589705..ce2d1feef 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,8 +67,9 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, + params, + updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -80,8 +81,7 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index a5c2732ac..4f670a85b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. 
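    # Worked example of the padding mask constructed below (values assumed for
    # illustration): with matrix_size=4 and padding_start=2,
    #   ix = (jnp.arange(4) < 2).astype(matrix.dtype)  ->  [1., 1., 0., 0.]
    # so multiplying by ix[jnp.newaxis, :] and ix[:, jnp.newaxis] zeroes every
    # entry outside the top-left 2x2 block, keeping the padded rows/columns
    # from influencing the convergence check.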
- ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,17 +1809,13 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1923,8 +1919,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors + >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2442,9 +2438,7 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2453,9 +2447,7 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 1ba56bbda..fbb7a612c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,8 +108,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -124,8 +125,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py 
b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 88718ab67..1cfa5deca 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,7 +7,8 @@ import jax.numpy as jnp import optax -from algoperf import spec, sharding_utils +from algoperf import sharding_utils +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -93,7 +94,8 @@ def update_params( mesh = sharding_utils.get_mesh() # Create shardings for each argument replicated = sharding_utils.get_replicated_sharding(mesh) # No partitioning - sharded = sharding_utils.get_naive_sharding_spec(mesh) # Partition along batch dimension + sharded = sharding_utils.get_naive_sharding_spec( + mesh) # Partition along batch dimension # Create the sharding rules for each argument arg_shardings = ( @@ -114,7 +116,7 @@ def update_params( replicated, # loss replicated # grad_norm ) - + # Jit with shardings jitted_train_step = jax.jit( train_step, @@ -122,7 +124,7 @@ def update_params( donate_argnums=(2, 3, 4), in_shardings=arg_shardings, out_shardings=out_shardings) - + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, rng, grad_clip, diff --git a/submission_runner.py b/submission_runner.py index 43d707519..eace6054d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,7 @@ from absl import logging import jax import tensorflow as tf + # New PRNG implementation for correct sharding jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) @@ -38,7 +39,6 @@ jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) -# jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") import torch import torch.distributed as dist @@ -392,8 +392,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
if prepare_for_eval is not None: diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 64401ef7f..656d07fb5 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -75,9 +75,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 01dc2895c..d30474f00 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 77e71c826..4b0a0b218 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 909fcd672..818eec672 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index db8928309..f576d136b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -423,7 +423,9 @@ def _test_submission(workload_name, if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') elif global_batch_size < n_gpus and FLAGS.framework == 'jax': - raise ValueError('Global batch size cannot be smaller than the number of GPUs when using JAX sharding.') + raise ValueError( + 'Global batch size cannot be smaller than the number of GPUs when using JAX sharding.' 
+ ) workload = _make_one_batch_workload(workload_class, workload_name, framework, From a80f4ec68a5affbe1386adbe649dd5dcde8310a4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Mar 2025 19:43:02 +0000 Subject: [PATCH 17/86] switch fastmri from pmap to jit --- .../workloads/fastmri/fastmri_jax/workload.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..cbe961a09 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -10,6 +10,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import sharding_utils import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim @@ -39,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_utils.replicate(params) + params = sharding_utils.shard_replicated(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -94,10 +95,15 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding() + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding() + ) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], @@ -126,7 +132,6 @@ def _eval_model(self, 'ssim': ssim_sum, 'loss': summed_loss, } - metrics = jax.lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -154,13 +159,12 @@ def _eval_model_on_split(self, num_batches=num_batches) total_metrics = {'ssim': 0., 'loss': 0.} - eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. 
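      # With the jitted _eval_model above (replicated out_shardings), each
      # metric is returned as a single global scalar rather than a per-device
      # array, which is why the per-device RNG split and the `[k][0]` indexing
      # are no longer needed here.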
- synced_metrics = self._eval_model(params, batch, eval_rngs) + synced_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} From c39ca51d5b00f7752538a3d3d5e35d5870f3462e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Mar 2025 20:48:09 +0000 Subject: [PATCH 18/86] migrate criteo workload --- .../criteo1tb/criteo1tb_jax/workload.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..2f41eb8c6 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf import sharding_utils from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -105,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return sharding_utils.shard_replicated(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -129,11 +130,14 @@ def model_fn( return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding() + ) + def _eval_batch_jitted(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( @@ -157,7 +161,7 @@ def _eval_batch(self, # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. 
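    # Note: the comment above describes the old pmapped behaviour; with
    # _eval_batch_jitted and a replicated output sharding a single,
    # already-reduced value comes back, so no per-device .sum() is applied.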
return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): From 06377d981b207dbcf8722b8dc36dd4cd474267e6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Mar 2025 20:48:43 +0000 Subject: [PATCH 19/86] update utils function used for sharding conformer --- .../librispeech_conformer/librispeech_jax/workload.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 69e4351d3..cd04d4746 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -94,15 +94,8 @@ def init_model_fn( self._param_types = param_utils.jax_param_types(self._param_shapes) # Add sharding - mesh = sharding_utils.get_mesh() - params = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), - params) - model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), - model_state) + params = sharding_utils.shard_replicated(params) + model_state = sharding_utils.shard_replicated(model_state) return params, model_state From 9cbe7d92e136d22977dce18c61c5712d9b106ba4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Mar 2025 07:21:24 +0000 Subject: [PATCH 20/86] update conformer and deepspeech --- .../librispeech_jax/workload.py | 21 +---------------- .../librispeech_jax/models.py | 4 ---- .../librispeech_jax/workload.py | 5 ++-- .../paper_baselines/adamw/jax/submission.py | 23 ++++++++----------- 4 files changed, 14 insertions(+), 39 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index cd04d4746..8e3f3d975 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -333,7 +333,6 @@ def _eval_step( decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) - targets, target_paddings = batch['targets'] # Convert metrics bundle to dictionary metrics_dict = { @@ -385,9 +384,6 @@ def _eval_model_on_split(self, global_step: int = 0) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step - if model_state is not None and len(model_state) > 0: - # Sync batch statistics across replicas before evaluating. 
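      # With model_state kept replicated under jit there is only one logical
      # copy of the batch statistics, so the explicit cross-replica sync (and
      # the sync_batch_stats helper itself) is no longer required.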
- model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: @@ -409,21 +405,6 @@ def _eval_model_on_split(self, return computed_metrics - @functools.partial( - jax.jit, - in_shardings=( - sharding_utils.get_replicated_sharding(), # model_state - ), - out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,)) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync batch statistics across replicas.""" - # Replace pmean with direct mean across devices - new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) - return model_state.copy({'batch_stats': new_batch_stats}) - class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): @@ -465,7 +446,7 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: return 0.094114 - + @property def test_target_value(self) -> float: return 0.056629 diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index c6e7275be..07481272f 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -308,16 +308,12 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v variance = (inputs - mean) * (inputs - mean) * mask sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..392de7b89 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -8,6 +8,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import sharding_utils from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -51,8 +52,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = sharding_utils.shard_replicated(model_state) + params = sharding_utils.shard_replicated(params) return params, model_state def model_fn( diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index d4dd0539b..5baa046f5 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -55,16 +55,15 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return optimizer_state, opt_update_fn - def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - 
current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -159,15 +158,13 @@ def update_params( replicated, # loss replicated # grad_norm ) - - # Jit with shardings jitted_train_step = jax.jit( train_step, static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings) - + out_shardings=out_shardings + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, From c6ecd676da7ebdc63bb43eddff7b08b36b6e2deb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Mar 2025 17:16:31 +0000 Subject: [PATCH 21/86] debugging --- .../librispeech_jax/models.py | 339 +++++++++++++++++- 1 file changed, 335 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 07481272f..5d6be2f5d 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -8,8 +8,10 @@ # webpage : https://bastings.github.io/ """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +import functools +import flax from flax import linen as nn from flax import struct import jax @@ -418,6 +420,321 @@ def unpack_weights( self.bidirectional, ) +### Swap in regular LSTM layer for debuggin +@jax.vmap +def flip_sequences(inputs: Array, lengths: Array) -> Array: + """Flips a sequence of inputs along the time dimension. + + This function can be used to prepare inputs for the reverse direction of a + bidirectional LSTM. It solves the issue that, when naively flipping multiple + padded sequences stored in a matrix, the first elements would be padding + values for those sequences that were padded. This function keeps the padding + at the end, while flipping the rest of the elements. + + Example: + ```python + inputs = [[1, 0, 0], + [2, 3, 0] + [4, 5, 6]] + lengths = [1, 2, 3] + flip_sequences(inputs, lengths) = [[1, 0, 0], + [3, 2, 0], + [6, 5, 4]] + ``` + + Args: + inputs: An array of input IDs [batch_size, seq_length]. + lengths: The length of each sequence [batch_size]. + + Returns: + An ndarray with the flipped inputs. + """ + # Compute the indices to put the inputs in flipped order as per above example. + max_length = inputs.shape[0] + idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length + return inputs[idxs] + +class GenericRNNSequenceEncoder(nn.Module): + """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`. + + The sequence can be encoded left-to-right (default) or right-to-left (by + calling the module with reverse=True). Regardless of encoding direction, + outputs[i, j, ...] is the representation of inputs[i, j, ...]. + + Attributes: + hidden_size: The hidden size of the RNN cell. + cell_type: The RNN cell module to use, for example, `nn.LSTMCell`. + cell_kwargs: Optional keyword arguments for the recurrent cell. + recurrent_dropout_rate: The dropout to apply across time steps. If this is + greater than zero, you must use an RNN cell that implements + `RecurrentDropoutCell` such as RecurrentDropoutOptimizedLSTMCell. 
+ """ + hidden_size: int + cell_type: Type[nn.RNNCellBase] + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + recurrent_dropout_rate: float = 0.0 + + def setup(self): + self.cell = self.cell_type(features=self.hidden_size, **self.cell_kwargs) + + @functools.partial( # Repeatedly calls the below method to encode the inputs. + nn.transforms.scan, + variable_broadcast='params', + in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), + out_axes=1, + split_rngs={'params': False}) + def unroll_cell(self, cell_state: StateType, inputs: Array, + recurrent_dropout_mask: Optional[Array], deterministic: bool): + """Unrolls a recurrent cell over an input sequence. + + Args: + cell_state: The initial cell state, shape: [batch_size, + hidden_size] (or an n-tuple thereof). + inputs: The input sequence. [batch_size, seq_len, input_dim]. + recurrent_dropout_mask: An optional recurrent dropout mask to apply in + between time steps. [batch_size, hidden_size]. + deterministic: Disables recurrent dropout when set to True. + + Returns: + The cell state after processing the complete sequence (including padding), + and a tuple with all intermediate cell states and cell outputs. + """ + # We do not directly scan the cell itself, since it only returns the output. + # This returns both the state and the output, so we can slice out the + # correct final states later. + new_cell_state, output = self.cell(cell_state, inputs) + return new_cell_state, (new_cell_state, output) + + def __call__(self, + inputs: Array, + lengths: Array, + initial_state: StateType, + reverse: bool = False, + deterministic: bool = False): + """Unrolls the RNN cell over the inputs. + + Arguments: + inputs: A batch of sequences. Shape: [batch_size, seq_len, + input_dim]. + lengths: The lengths of the input sequences. + initial_state: The initial state for the RNN cell. Shape: [batch_size, + hidden_size]. + reverse: Process the inputs in reverse order, and reverse the outputs. + This means that the outputs still correspond to the order of the inputs, + but their contexts come from the right, and not from the left. + deterministic: Disables recurrent dropout if set to True. + + Returns: + The encoded sequence of inputs, shaped [batch_size, seq_len, + hidden_size], as well as the final hidden states of the RNN cell. For an + LSTM cell the final states are a tuple (c, h), each shaped [ + batch_size, hidden_size]. + """ + if reverse: + inputs = flip_sequences(inputs, lengths) + + recurrent_dropout_mask = None + _, (cell_states, outputs) = self.unroll_cell(initial_state, inputs, + recurrent_dropout_mask, + deterministic) + final_state = jax.tree.map( + lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states) + + if reverse: + outputs = flip_sequences(outputs, lengths) + + return outputs, final_state + + +class GenericRNN(nn.Module): + """Generic RNN class. + + This provides generic RNN functionality to encode sequences with any RNN cell. + The class provides unidirectional and bidirectional layers, and these are + stacked when asking for multiple layers. + + This class be used to create a specific RNN class such as LSTM or GRU. + + Attributes: + cell_type: An RNN cell class to use, e.g., `flax.linen.LSTMCell`. + hidden_size: The size of each recurrent cell. + num_layers: The number of stacked recurrent layers. The output of the first + layer, with optional dropout applied, feeds into the next layer. + dropout_rate: Dropout rate to be applied between LSTM layers. Only applies + when num_layers > 1. 
+ recurrent_dropout_rate: Dropout rate to be applied on the hidden state at + each time step repeating the same dropout mask. + bidirectional: Process the sequence left-to-right and right-to-left and + concatenate the outputs from the two directions. + cell_kwargs: Optional keyword arguments to instantiate the cell with. + """ + cell_type: Type[nn.RNNCellBase] + hidden_size: int + num_layers: int = 1 + dropout_rate: float = 0. + recurrent_dropout_rate: float = 0. + bidirectional: bool = False + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + + @nn.compact + def __call__( + self, + inputs: Array, + lengths: Array, + initial_states: Optional[Sequence[StateType]] = None, + deterministic: bool = False + ) -> Tuple[Array, Sequence[StateType]]: + """Processes the input sequence using the recurrent cell. + + Args: + inputs: The input sequence [batch_size, sequence_length, ...] + lengths: The lengths of each sequence in the batch. [batch_size] + initial_states: The initial states for the cells. You must provide + `num_layers` initial states (when using bidirectional, `num_layers * + 2`). + These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, + all initial states will be initialized with zeros. + deterministic: Disables dropout between layers when set to True. + Returns: + The sequence of all outputs for the final layer, and a list of final + states for each cell and direction. Directions are alternated (first + forward, then backward, if bidirectional). For example for a bidirectional + cell this would be: layer 1 forward, layer 1 backward, layer 2 forward, + layer 2 backward, etc.. + For some cells like LSTMCell a state consists of an (c, h) tuple, while + for others cells it only contains a single vector (h,). + """ + batch_size = inputs.shape[0] + final_states = [] + num_directions = 2 if self.bidirectional else 1 + num_cells = self.num_layers * num_directions + + # Construct initial states. + if initial_states is None: # Initialize with zeros. + rng = jax.random.PRNGKey(0) + initial_states = [ + self.cell_type(self.hidden_size).initialize_carry( + rng, (batch_size, 1) + ) + for _ in range(num_cells) + ] + if len(initial_states) != num_cells: + raise ValueError( + f'Please provide {self.num_cells} (`num_layers`, *2 if bidirectional)' + 'initial states.' + ) + + # For each layer, apply the forward and optionally the backward RNN cell. + cell_idx = 0 + for _ in range(self.num_layers): + # Unroll an RNN cell (forward direction) for this layer. + outputs, final_state = GenericRNNSequenceEncoder( + cell_type=self.cell_type, + cell_kwargs=self.cell_kwargs, + hidden_size=self.hidden_size, + recurrent_dropout_rate=self.recurrent_dropout_rate, + name=f'{self.name}SequenceEncoder_{cell_idx}')( + inputs, + lengths, + initial_state=initial_states[cell_idx], + deterministic=deterministic) + final_states.append(final_state) + cell_idx += 1 + + # Unroll an RNN cell (backward direction) for this layer. 
+ if self.bidirectional: + backward_outputs, backward_final_state = GenericRNNSequenceEncoder( + cell_type=self.cell_type, + cell_kwargs=self.cell_kwargs, + hidden_size=self.hidden_size, + recurrent_dropout_rate=self.recurrent_dropout_rate, + name=f'{self.name}SequenceEncoder_{cell_idx}')( + inputs, + lengths, + initial_state=initial_states[cell_idx], + reverse=True, + deterministic=deterministic) + outputs = jnp.concatenate([outputs, backward_outputs], axis=-1) + final_states.append(backward_final_state) + cell_idx += 1 + + inputs = outputs + + return outputs, final_states + + +class LSTM(nn.Module): + """LSTM. + + Attributes: + hidden_size: The size of each recurrent cell. + num_layers: The number of stacked recurrent layers. The output of the first + layer, with optional dropout applied, feeds into the next layer. + dropout_rate: Dropout rate to be applied between LSTM layers. Only applies + when num_layers > 1. + recurrent_dropout_rate: Dropout rate to be applied on the hidden state at + each time step repeating the same dropout mask. + bidirectional: Process the sequence left-to-right and right-to-left and + concatenate the outputs from the two directions. + cell_type: The LSTM cell class to use. Default: + `flax.linen.OptimizedLSTMCell`. If you use hidden_size of >2048, consider + using `flax.linen.LSTMCell` instead, since the optimized LSTM cell works + best for hidden sizes up to 2048. + cell_kwargs: Optional keyword arguments to instantiate the cell with. + """ + hidden_size: int + num_layers: int = 1 + dropout_rate: float = 0. + recurrent_dropout_rate: float = 0. + bidirectional: bool = False + cell_type: Any = nn.OptimizedLSTMCell + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + + @nn.compact + def __call__( + self, + inputs: Array, + lengths: Array, + initial_states: Optional[Sequence[StateType]] = None, + deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + """Processes an input sequence with an LSTM cell. + + Example usage: + ``` + inputs = np.random.normal(size=(2, 3, 4)) + lengths = np.array([1, 3]) + outputs, final_states = LSTM(hidden_size=10).apply(rngs, inputs, lengths) + ``` + + Args: + inputs: The input sequence [batch_size, sequence_length, ...] + lengths: The lengths of each sequence in the batch. [batch_size] + initial_states: The initial states for the cells. You must provide + `num_layers` initial states (when using bidirectional, `num_layers * + 2`). These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, + all initial states will be initialized with zeros. + deterministic: Disables dropout between layers when set to True. + + Returns: + The sequence of all outputs for the final layer, and a list of final + states (h, c) for each cell and direction, ordered first by layer number + and then by direction (first forward, then backward, if bidirectional). + """ + return GenericRNN( + cell_type=self.cell_type, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + dropout_rate=self.dropout_rate, + recurrent_dropout_rate=self.recurrent_dropout_rate, + bidirectional=self.bidirectional, + cell_kwargs=self.cell_kwargs, + name='LSTM')( + inputs, + lengths, + initial_states=initial_states, + deterministic=deterministic) class BatchRNN(nn.Module): """Implements a single deepspeech encoder layer. 
@@ -437,10 +754,24 @@ def __call__(self, inputs, input_paddings, train): config.batch_norm_epsilon)(inputs, input_paddings, train) - output = CudnnLSTM( - features=config.encoder_dim // 2, + + # For regular LSTM + hidden_size = ( + config.encoder_dim // 2 if config.bidirectional else config.encoder_dim + ) + lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) + + output, _ = LSTM( + hidden_size=hidden_size, bidirectional=config.bidirectional, - num_layers=1)(inputs, input_paddings) + num_layers=1, + )(inputs, lengths) + + # output = CudnnLSTM( + # features=config.encoder_dim // 2, + # bidirectional=config.bidirectional, + # num_layers=1)(inputs, input_paddings) + return output From f35690de13d415ce00619ad66676bb3cd2bc16a7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 12 Mar 2025 20:03:07 +0000 Subject: [PATCH 22/86] debuging --- .../librispeech_jax/models.py | 38 ++++++++++++++----- .../paper_baselines/adamw/jax/submission.py | 2 +- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 5d6be2f5d..e1d76730d 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -9,6 +9,9 @@ """ from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +from absl import logging + +import numpy as np import functools import flax @@ -385,6 +388,23 @@ def __call__( seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32) if use_cuda: + inputs_shape = np.shape(inputs) + h_0_shape = np.shape(h_0) + c_0_shape = np.shape(c_0) + weights_shape = np.shape(weights) + seq_lengths_np = np.shape(seq_lengths) + + n = jax.devices() + logging.info(f"jax num devices {n}") + logging.info(f'inputs shape {inputs_shape}') + logging.info(f'h_0 shape {h_0_shape}') + logging.info(f'c_0 shape {c_0_shape}') + logging.info(f'seq_lengths shape {seq_lengths_np}') + logging.info(f'weights_shape {weights_shape}') + logging.info(f'input_size {input_size}') + logging.info(f'hidden_size {self.features}') + logging.info(f'num_layers {self.num_layers}') + y, h, c = rnn.lstm( x=inputs, h_0=h_0, c_0=c_0, weights=weights, seq_lengths=seq_lengths, input_size=input_size, @@ -761,16 +781,16 @@ def __call__(self, inputs, input_paddings, train): ) lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) - output, _ = LSTM( - hidden_size=hidden_size, - bidirectional=config.bidirectional, - num_layers=1, - )(inputs, lengths) - - # output = CudnnLSTM( - # features=config.encoder_dim // 2, + # output, _ = LSTM( + # hidden_size=hidden_size, # bidirectional=config.bidirectional, - # num_layers=1)(inputs, input_paddings) + # num_layers=1, + # )(inputs, lengths) + + output = CudnnLSTM( + features=config.encoder_dim // 2, + bidirectional=config.bidirectional, + num_layers=1)(inputs, input_paddings) return output diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 5baa046f5..d381d3dfd 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -223,7 +223,7 @@ def get_batch_size(workload_name): elif workload_name == 'librispeech_conformer': return 256 elif workload_name == 'librispeech_deepspeech': - return 256 + return 32 elif workload_name == 'ogbg': return 512 elif 
workload_name == 'wmt': From 848b50c7251ed8330510a3eb6853d9acafb6c265 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:23:04 +0000 Subject: [PATCH 23/86] reformatting --- tests/modeldiffs/wmt/compare.py | 6 +++--- tests/modeldiffs/wmt_attention_temp/compare.py | 6 +++--- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +++--- tests/modeldiffs/wmt_post_ln/compare.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 656d07fb5..64401ef7f 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -75,9 +75,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index d30474f00..01dc2895c 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 4b0a0b218..77e71c826 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 818eec672..909fcd672 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From fb62eae40f2a5cb00c27118c26423256ebe5cd8e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:26:05 +0000 Subject: [PATCH 24/86] reformatting --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index eace6054d..a5c59000b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -392,8 +392,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
if prepare_for_eval is not None: From fe3f9f021c7f621fea099746951a85d45e550df1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:26:45 +0000 Subject: [PATCH 25/86] reformatting --- .../paper_baselines/adamw/jax/submission.py | 20 ++++++------- .../paper_baselines/nadamw/jax/submission.py | 10 +++---- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 ++++++++++++------- .../target_setting_algorithms/jax_nadamw.py | 10 +++---- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index d381d3dfd..5bc0644a8 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -55,15 +55,16 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return optimizer_state, opt_update_fn + def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -163,8 +164,7 @@ def update_params( static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index ade3e9f63..c451a18ac 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index ce2d1feef..b76589705 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. 
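    # Expanded view of the perturbation step below, assuming dual_vector
    # rescales the whole gradient tree to unit L2 norm:
    #   g_hat = grads / ||grads||_2
    #   p_adv = p + rho * g_hat          (for every parameter leaf p)
    # The gradient is then re-evaluated at p_adv, which is the
    # sharpness-aware part of SAM.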
updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, - params, - updates) + noised_params = jax.tree_util.tree_map( + lambda p, u: p + rho * u, params, updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, + lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 4f670a85b..a5c2732ac 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree.map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree.map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), errors - >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), + errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. 
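      # `predicate` is 1.0 exactly where the inverse pth-root failed (NaN error
      # or error >= inverse_failure_threshold); the blend below therefore keeps
      # the previous preconditioner for failed matrices and accepts the freshly
      # computed one otherwise.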
new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2438,7 +2442,9 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree.map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2447,7 +2453,9 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree.map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index fbb7a612c..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,9 +108,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From 004afbd541c710cf37ab92546119011ffc36ad28 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:27:15 +0000 Subject: [PATCH 26/86] reformatting --- .../external_tuning/jax_nadamw_full_budget.py | 10 ++++------ .../external_tuning/jax_nadamw_target_setting.py | 10 ++++------ .../self_tuning/jax_nadamw_full_budget.py | 10 ++++------ .../self_tuning/jax_nadamw_target_setting.py | 10 ++++------ 4 files changed, 16 insertions(+), 24 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index ade3e9f63..c451a18ac 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + 
return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index d3a27411c..b8ac10f33 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 9bc014ed1..78c3b5b3e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index a6781612c..ffe854a0e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, 
ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From f1db3d3e6a7f699ad4e3dcb04b11320f55d10a5c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:27:51 +0000 Subject: [PATCH 27/86] reformatting --- algoperf/profiler.py | 4 +- .../criteo1tb/criteo1tb_jax/workload.py | 21 ++++----- .../workloads/fastmri/fastmri_jax/workload.py | 11 ++--- .../fastmri/fastmri_pytorch/workload.py | 4 +- .../imagenet_jax/custom_tf_addons.py | 46 +++++++++---------- .../imagenet_jax/randaugment.py | 8 ++-- .../imagenet_resnet/imagenet_jax/workload.py | 8 ++-- .../imagenet_pytorch/workload.py | 4 +- .../librispeech_jax/models.py | 10 ++-- .../librispeech_jax/spectrum_augmenter.py | 4 +- .../librispeech_jax/workload.py | 2 +- .../librispeech_pytorch/workload.py | 9 ++-- .../librispeech_jax/models.py | 41 +++++++++-------- algoperf/workloads/mnist/workload.py | 7 ++- algoperf/workloads/wmt/wmt_pytorch/models.py | 4 +- 15 files changed, 95 insertions(+), 88 deletions(-) diff --git a/algoperf/profiler.py b/algoperf/profiler.py index d73efd964..fa2a1bee2 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) - for a, d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) for a, + d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 2f41eb8c6..30ecd2b00 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -130,16 +130,16 @@ def model_fn( return logits_batch, None @functools.partial( - jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), - ), - static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding() - ) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_batch_jitted(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( params, batch, @@ -160,8 +160,7 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. 
- return np.array( - self._eval_batch_jitted(params, batch), dtype=np.float64) + return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index cbe961a09..2709ef1e6 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -96,14 +96,11 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=( - sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), - sharding_utils.get_replicated_sharding() - ), + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding()), static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding() - ) + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 216a033d4..58943de2f 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -250,7 +250,9 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index c9bf154bb..3d6939218 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -20,31 +20,27 @@ tf.dtypes.float64, } -Number = Union[ - float, - int, - np.float16, - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, -] - -TensorLike = Union[ - List[Union[Number, list]], - tuple, - Number, - np.ndarray, - tf.Tensor, - tf.SparseTensor, - tf.Variable, -] +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] def get_ndims(image): diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 2d7e873c2..c68e2de33 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -316,7 +316,8 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. 
result = tf.cond( - tf.equal(step, 0), lambda: im, + tf.equal(step, 0), + lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -551,6 +552,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, selected_args=args: selected_func( - image, *selected_args), lambda: image) + lambda selected_func=func, + selected_args=args: selected_func(image, *selected_args), + lambda: image) return image diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 6a7cf03be..87b9b82bd 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -105,12 +105,12 @@ def init_model_fn( self._param_types = param_utils.jax_param_types(self._param_shapes) mesh = sharding_utils.get_mesh() params = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), + lambda x: jax.device_put(x, + sharding_utils.get_replicated_sharding(mesh)), params) model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), + lambda x: jax.device_put(x, + sharding_utils.get_replicated_sharding(mesh)), model_state) return params, model_state diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index bd07edee1..ed29271f3 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -307,7 +307,9 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 2bb527a36..593d463c3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,10 +442,12 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index c16740629..2a6f73d4d 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ 
b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights - < multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights < + multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 8e3f3d975..d1bf29ba4 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -446,7 +446,7 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: return 0.094114 - + @property def test_target_value(self) -> float: return 0.056629 diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 1dae389d7..5ed37957e 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -259,9 +259,8 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], - device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( - -1, 1) + fin_result.shape[1], device=result.device).view( + 1, -1) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -329,7 +328,9 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index e1d76730d..0646aac52 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -144,8 +144,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -278,10 +278,12 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) @@ -393,7 +395,7 @@ def __call__( c_0_shape = np.shape(c_0) weights_shape = np.shape(weights) seq_lengths_np = np.shape(seq_lengths) - + n = jax.devices() logging.info(f"jax num devices {n}") 
logging.info(f'inputs shape {inputs_shape}') @@ -440,6 +442,7 @@ def unpack_weights( self.bidirectional, ) + ### Swap in regular LSTM layer for debuggin @jax.vmap def flip_sequences(inputs: Array, lengths: Array) -> Array: @@ -474,6 +477,7 @@ def flip_sequences(inputs: Array, lengths: Array) -> Array: idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length return inputs[idxs] + class GenericRNNSequenceEncoder(nn.Module): """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`. @@ -503,8 +507,11 @@ def setup(self): in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), out_axes=1, split_rngs={'params': False}) - def unroll_cell(self, cell_state: StateType, inputs: Array, - recurrent_dropout_mask: Optional[Array], deterministic: bool): + def unroll_cell(self, + cell_state: StateType, + inputs: Array, + recurrent_dropout_mask: Optional[Array], + deterministic: bool): """Unrolls a recurrent cell over an input sequence. Args: @@ -554,7 +561,8 @@ def __call__(self, inputs = flip_sequences(inputs, lengths) recurrent_dropout_mask = None - _, (cell_states, outputs) = self.unroll_cell(initial_state, inputs, + _, (cell_states, outputs) = self.unroll_cell(initial_state, + inputs, recurrent_dropout_mask, deterministic) final_state = jax.tree.map( @@ -602,8 +610,7 @@ def __call__( inputs: Array, lengths: Array, initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False - ) -> Tuple[Array, Sequence[StateType]]: + deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: """Processes the input sequence using the recurrent cell. Args: @@ -635,15 +642,12 @@ def __call__( rng = jax.random.PRNGKey(0) initial_states = [ self.cell_type(self.hidden_size).initialize_carry( - rng, (batch_size, 1) - ) - for _ in range(num_cells) + rng, (batch_size, 1)) for _ in range(num_cells) ] if len(initial_states) != num_cells: raise ValueError( f'Please provide {self.num_cells} (`num_layers`, *2 if bidirectional)' - 'initial states.' - ) + 'initial states.') # For each layer, apply the forward and optionally the backward RNN cell. cell_idx = 0 @@ -756,6 +760,7 @@ def __call__( initial_states=initial_states, deterministic=deterministic) + class BatchRNN(nn.Module): """Implements a single deepspeech encoder layer. 
""" @@ -775,10 +780,9 @@ def __call__(self, inputs, input_paddings, train): input_paddings, train) - # For regular LSTM + # For regular LSTM hidden_size = ( - config.encoder_dim // 2 if config.bidirectional else config.encoder_dim - ) + config.encoder_dim // 2 if config.bidirectional else config.encoder_dim) lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) # output, _ = LSTM( @@ -791,7 +795,6 @@ def __call__(self, inputs, input_paddings, train): features=config.encoder_dim // 2, bidirectional=config.bidirectional, num_layers=1)(inputs, input_paddings) - return output diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index c92dd141b..f53aadd0b 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -46,7 +46,8 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'],}) + 'targets': x['label'], + }) is_train = split == 'train' if cache: @@ -213,6 +214,8 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index 089f1bfbb..a1c7ce15e 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) - >= cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. 
seq_len = k.size(1) From c208cc7a760b5883c4127c79d457e0d352fd873b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 19 Mar 2025 19:59:51 +0000 Subject: [PATCH 28/86] sharding deepspeech --- .../librispeech_jax/models.py | 18 +++++++++--------- .../librispeech_jax/workload.py | 17 +++++++++++++++++ pyproject.toml | 8 ++------ .../paper_baselines/adamw/jax/submission.py | 2 +- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 0646aac52..1cab9e89e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -397,15 +397,15 @@ def __call__( seq_lengths_np = np.shape(seq_lengths) n = jax.devices() - logging.info(f"jax num devices {n}") - logging.info(f'inputs shape {inputs_shape}') - logging.info(f'h_0 shape {h_0_shape}') - logging.info(f'c_0 shape {c_0_shape}') - logging.info(f'seq_lengths shape {seq_lengths_np}') - logging.info(f'weights_shape {weights_shape}') - logging.info(f'input_size {input_size}') - logging.info(f'hidden_size {self.features}') - logging.info(f'num_layers {self.num_layers}') + # logging.info(f"jax num devices {n}") + # logging.info(f'inputs shape {inputs_shape}') + # logging.info(f'h_0 shape {h_0_shape}') + # logging.info(f'c_0 shape {c_0_shape}') + # logging.info(f'seq_lengths shape {seq_lengths_np}') + # logging.info(f'weights_shape {weights_shape}') + # logging.info(f'input_size {input_size}') + # logging.info(f'hidden_size {self.features}') + # logging.info(f'num_layers {self.num_layers}') y, h, c = rnn.lstm( x=inputs, h_0=h_0, c_0=c_0, weights=weights, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 392de7b89..c29e952cb 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -4,6 +4,8 @@ from flax import jax_utils import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P import numpy as np from algoperf import param_utils @@ -66,6 +68,21 @@ def model_fn( update_batch_norm: bool, use_running_average_bn: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + model_fn_sharded = shard_map(model_fn_ref, + self.mesh, + ) + + def model_fn_ref( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN diff --git a/pyproject.toml b/pyproject.toml index cc404f4b5..f4ebdaee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,15 +106,11 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", + "jax", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + "jax[cuda12]", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", 
"torchvision==0.20.1"] diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 5bc0644a8..8dc71950e 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -74,7 +74,7 @@ def _loss_fn(params): model_state, spec.ForwardPassMode.TRAIN, rng, - update_batch_norm=True) + update_batch_norm=True,) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, From 2e4cc9e9184a371f92af1989bb9a4ac8648d71ce Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 19 Mar 2025 23:15:15 +0000 Subject: [PATCH 29/86] ogbg jit migration --- algoperf/workloads/ogbg/input_pipeline.py | 9 +++- algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 24 ++++++++--- algoperf/workloads/ogbg/workload.py | 1 + .../paper_baselines/adamw/jax/submission.py | 43 +++++++++---------- submission_runner.py | 5 ++- 6 files changed, 53 insertions(+), 33 deletions(-) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..9506a5343 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -148,10 +148,15 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: + # yield { + # 'inputs': jraph.batch(graphs_shards), + # 'targets': np.vstack(labels_shards), + # 'weights': np.vstack(weights_shards) + # } def f(x): - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) - + return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:]) + graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) weights_shards = f(weights_shards) diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..9607a14e1 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -2,6 +2,7 @@ # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. 
from typing import Optional, Tuple +import jax from flax import linen as nn import jax.numpy as jnp import jraph @@ -78,7 +79,8 @@ def __call__(self, graph, train): self.hidden_dims, dropout=dropout, activation_fn=activation_fn), update_global_fn=_make_mlp( self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) - + # jax.debug.print(str(graph)) + graph = net(graph) # Map globals to represent the final result diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..347f89721 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,6 +8,7 @@ import jraph import optax +from algoperf import sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.ogbg import metrics @@ -45,7 +46,8 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(params), None + params = sharding_utils.shard_replicated(params) + return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' @@ -106,11 +108,20 @@ def _eval_metric(self, labels, logits, masks): return metrics.EvalMetrics.single_from_model_output( loss=loss['per_example'], logits=logits, labels=labels, mask=masks) + # @functools.partial( + # jax.pmap, + # axis_name='batch', + # in_axes=(None, 0, 0, 0, None), + # static_broadcasted_argnums=(0,)) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding(), + sharding_utils.get_replicated_sharding()), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding(), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) @@ -119,7 +130,8 @@ def _normalize_eval_metrics( Any]) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples - total_metrics = total_metrics.reduce() + # total_metrics = total_metrics.reduce() + print(total_metrics) return {k: float(v) for k, v in total_metrics.compute().items()} diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..14666c081 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -161,6 +161,7 @@ def _eval_batch(self, spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) + jax.debug.print(str(logits)) return self._eval_metric(batch['targets'], logits, batch['weights']) def _eval_model_on_split(self, diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 8dc71950e..6c6d19ef8 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,6 +75,7 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) + jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -140,31 +141,29 @@ def update_params( replicated = NamedSharding(mesh, P()) # No partitioning sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension - # Define input and 
output shardings - arg_shardings = ( - # workload is static - # opt_update_fn is static - replicated, # model_state - replicated, # optimizer_state - replicated, # current_param_container - sharded, # batch - replicated, # rng - replicated, # grad_clip - replicated # label_smoothing - ) - out_shardings = ( - replicated, # new_optimizer_state - replicated, # updated_params - replicated, # new_model_state - replicated, # loss - replicated # grad_norm - ) jitted_train_step = jax.jit( train_step, static_argnums=(0, 1), donate_argnums=(2, 3, 4), - in_shardings=arg_shardings, - out_shardings=out_shardings) + in_shardings= ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ), + out_shardings=( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + )) + # print(batch) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, @@ -176,7 +175,7 @@ def update_params( label_smoothing) # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: + if global_step % 1 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { 'loss': loss.item(), diff --git a/submission_runner.py b/submission_runner.py index a5c59000b..bc2c49b99 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -392,8 +392,9 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if False: + # if ((train_step_end_time - train_state['last_eval_time']) >= + # workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
if prepare_for_eval is not None: From d3a06fcd8ef9e705144c86c2bce38a7616079c44 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 20 Mar 2025 15:28:12 +0000 Subject: [PATCH 30/86] deepspeech jit changes --- .../librispeech_jax/workload.py | 45 ++++++++++++------- .../paper_baselines/adamw/jax/submission.py | 2 +- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index c29e952cb..c2d56c1cc 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -57,22 +57,7 @@ def init_model_fn( model_state = sharding_utils.shard_replicated(model_state) params = sharding_utils.shard_replicated(params) return params, model_state - - def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None - ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - - model_fn_sharded = shard_map(model_fn_ref, - self.mesh, - ) - + def model_fn_ref( self, params: spec.ParameterContainer, @@ -104,6 +89,34 @@ def model_fn_ref( mutable=False) return (logits, logit_paddings), model_state + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + model_fn_partial = jax.tree_util.Partial(self.model_fn_ref, + mode=mode, + rng=rng, + update_batch_norm=update_batch_norm, + use_running_average_bn=use_running_average_bn) + + model_fn_sharded = shard_map(model_fn_partial, + sharding_utils.get_mesh(), + in_specs=(None, P('batch'), None), + out_specs=(P('batch'), None), + ) + + model_fn_sharded = model_fn_partial + return model_fn_sharded(params, + augmented_and_preprocessed_input_batch, + model_state,) + def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 6c6d19ef8..fc5d74541 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -222,7 +222,7 @@ def get_batch_size(workload_name): elif workload_name == 'librispeech_conformer': return 256 elif workload_name == 'librispeech_deepspeech': - return 32 + return 256 elif workload_name == 'ogbg': return 512 elif workload_name == 'wmt': From 2cfa2a99bd15bf6a55dcb210a5b118b2440b77f7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 20 Mar 2025 21:20:18 +0000 Subject: [PATCH 31/86] set jax to 0.5.1 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f4ebdaee3..ef161b75c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,11 +106,11 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax", + "jax==0.5.2", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax[cuda12]", + "jax[cuda12]==0.5.2", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", 
"torchvision==0.20.1"] From 70705a74949f41823a145629d50e8de6ecb6e3e9 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 20 Mar 2025 23:43:37 +0000 Subject: [PATCH 32/86] merge --- .../librispeech_jax/workload.py | 1 + pyproject.toml | 2 +- .../paper_baselines/adamw/jax/submission.py | 3 +- scoring/utils/slurm/make_job_config.py | 118 ++++++++++++++++++ submission_runner.py | 30 ++--- 5 files changed, 137 insertions(+), 17 deletions(-) create mode 100644 scoring/utils/slurm/make_job_config.py diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index c2d56c1cc..5c34cbd74 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -44,6 +44,7 @@ def init_model_fn( fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) + # model_init_fn = functools.partial(self._model.init, train=False) params_rng, dropout_rng = jax.random.split(rng, 2) variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, diff --git a/pyproject.toml b/pyproject.toml index ef161b75c..283ca4d05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] # Frameworks jax_core_deps = [ - "flax==0.8.4", + "flax==0.10.4", "optax==0.2.2", "chex==0.1.86", "ml_dtypes==0.4.1", diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index fc5d74541..10ff82de7 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,7 +75,6 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) - jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -222,7 +221,7 @@ def get_batch_size(workload_name): elif workload_name == 'librispeech_conformer': return 256 elif workload_name == 'librispeech_deepspeech': - return 256 + return 16 elif workload_name == 'ogbg': return 512 elif workload_name == 'wmt': diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py new file mode 100644 index 000000000..20576af66 --- /dev/null +++ b/scoring/utils/slurm/make_job_config.py @@ -0,0 +1,118 @@ +import json +import os + +from absl import app +from absl import flags +import jax + +SUBMISSION_PATH = 'prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py' +EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/baseline' +TUNING_SEARCH_SPACE = None +FRAMEWORK = 'jax' +TUNING_RULESET = 'self' + +flags.DEFINE_string('submission_path', + SUBMISSION_PATH, + 'Path to submission module.') +flags.DEFINE_string('tuning_search_space', + TUNING_SEARCH_SPACE, + 'Path to tuning search space for submission module.') +flags.DEFINE_string('experiment_dir', + EXPERIMENT_DIR, + 'Path to experiment dir where logs will be saved.') +flags.DEFINE_enum( + 'framework', + FRAMEWORK, + enum_values=['jax', 'pytorch'], + help='Can be either pytorch or jax.') +flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.') +flags.DEFINE_enum( + 'tuning_ruleset', + TUNING_RULESET, + enum_values=['external', 'self'], + help='Which tuning ruleset to score this submission on. 
Can be external or self.' +) + +FLAGS = flags.FLAGS + +MIN_INT = -2**(31) +MAX_INT = 2**(31) - 1 +NUM_TUNING_TRIALS = 5 # For external tuning ruleset +NUM_STUDIES = 3 + +WORKLOADS = { + "imagenet_resnet": {"dataset": "imagenet"}, + "imagenet_vit": {"dataset": "imagenet"}, + "fastmri": {"dataset": "fastmri"}, + "ogbg": {"dataset": "ogbg"}, + "wmt": {"dataset": "wmt"}, + "librispeech_deepspeech": {"dataset": "librispeech"}, + "criteo1tb": {"dataset": "criteo1tb"}, + "librispeech_conformer": {"dataset": "librispeech"} +} + + +def main(_): + workloads = WORKLOADS.keys() + key = jax.random.key(FLAGS.seed) + + jobs = [] + + for workload in workloads: + # Fold in hash(workload) mod(max(uint32)) + workload_key = jax.random.fold_in(key, hash(workload) % (2**32 - 1)) + for study_index in range(NUM_STUDIES): + study_key = jax.random.fold_in(workload_key, study_index) + if FLAGS.tuning_ruleset == 'external': + for hparam_index in range(NUM_TUNING_TRIALS): + run_key = jax.random.fold_in(study_key, hparam_index) + seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item() + print(seed) + # Add job + job = {} + study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}") + job['framework'] = FLAGS.framework + job['workload'] = workload + job['dataset'] = WORKLOADS[workload]['dataset'] + job['submission_path'] = FLAGS.submission_path + job['experiment_dir'] = study_dir + job['rng_seed'] = seed + job['tuning_ruleset'] = FLAGS.tuning_ruleset + job['num_tuning_trials'] = NUM_TUNING_TRIALS + job['hparam_start_index'] = hparam_index + job['hparam_end_index'] = hparam_index + 1 + job['tuning_search_space'] = FLAGS.tuning_search_space + job['tuning_ruleset'] = FLAGS.tuning_ruleset + jobs.append(job) + print(job) + + else: + run_key = study_key + seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item() + print(seed) + # Add job + job = {} + study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}") + job['framework'] = FLAGS.framework + job['workload'] = workload + job['dataset'] = WORKLOADS[workload]['dataset'] + job['submission_path'] = FLAGS.submission_path + job['experiment_dir'] = study_dir + job['rng_seed'] = seed + job['tuning_ruleset'] = FLAGS.tuning_ruleset + job['num_tuning_trials'] = 1 + + jobs.append(job) + print(job) + + # Convert job array to dict with job indices + job_dict = {} + for i, job in enumerate(jobs): + job_dict[f"{i}"] = job + + with open('config.json', 'w') as f: + json.dump(job_dict, f, indent=4) + + +if __name__ == '__main__': + app.run(main) diff --git a/submission_runner.py b/submission_runner.py index bc2c49b99..b7fdf117f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -636,20 +636,22 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space[hi] = hyperparameters with profiler.profile('Train'): - timing, metrics = train_once(workload, workload_name, - global_batch_size, - global_eval_batch_size, - data_dir, imagenet_v2_data_dir, - init_optimizer_state, - update_params, data_selection, - prepare_for_eval, - hyperparameters, - rng_seed, - rng, - profiler, - max_global_steps, - tuning_dir_name, - save_checkpoints=save_checkpoints,) + with jax.profiler.trace("/logs/tensorboard"): + print('profiling!') + timing, metrics = train_once(workload, workload_name, + global_batch_size, + global_eval_batch_size, + data_dir, imagenet_v2_data_dir, + init_optimizer_state, + update_params, data_selection, + prepare_for_eval, + hyperparameters, + rng_seed, + rng, + profiler, + max_global_steps, + tuning_dir_name, + 
save_checkpoints=save_checkpoints,) all_timings[hi] = timing all_metrics[hi] = metrics logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') From 75d63157b8a770e59225b10f2ea5dba48cdfe65d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 1 Apr 2025 22:33:26 +0000 Subject: [PATCH 33/86] upgrade jax to 0.5.3 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 283ca4d05..874297256 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,11 +106,11 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.5.2", + "jax==0.5.3", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax[cuda12]==0.5.2", + "jax[cuda12]==0.5.3", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From 1df069021c7d04d7f065ba7fbafff4cb3103cd64 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 1 Apr 2025 22:58:21 +0000 Subject: [PATCH 34/86] change bsz back --- reference_algorithms/paper_baselines/adamw/jax/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 10ff82de7..3423f6e85 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -221,7 +221,7 @@ def get_batch_size(workload_name): elif workload_name == 'librispeech_conformer': return 256 elif workload_name == 'librispeech_deepspeech': - return 16 + return 256 elif workload_name == 'ogbg': return 512 elif workload_name == 'wmt': From c1d0c6689be8236cbc069cef75b21797353f6481 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 02:47:33 +0000 Subject: [PATCH 35/86] formatting --- .github/workflows/CI.yml | 20 +++++++++--------- .../cifar/cifar_jax/input_pipeline.py | 1 + .../workloads/cifar/cifar_jax/workload.py | 21 ++++++++----------- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ccd99e68d..64ef0302e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json wmt_pytorch: runs-on: ubuntu-latest steps: @@ -54,7 +54,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json imagenet_jax: runs-on: ubuntu-latest steps: @@ -71,8 +71,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json imagenet_pytorch: runs-on: ubuntu-latest steps: @@ -89,8 +89,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json # uncomment when https://github.com/mlcommons/algorithmic-efficiency/issues/339 is resolved. criteo_jax: runs-on: ubuntu-latest @@ -142,8 +142,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json speech_pytorch: runs-on: ubuntu-latest steps: @@ -160,8 +160,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ogbg: runs-on: ubuntu-latest steps: diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 18cb9ac5b..459ed9266 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -171,5 +171,6 @@ def create_input_iter( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) # FIXME(rka97): Figure out how to do prefetching+sharding. 
+ # TODO (kasimbeg) # it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..911812455 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -32,12 +32,12 @@ def _build_cifar_dataset( ) -> Iterator[Dict[str, spec.Tensor]]: data_dir = data_dir + "/cifar10" ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) - train = split == "train" + train = split == 'train' assert self.num_train_examples + self.num_validation_examples == 50000 - if split in ["train", "eval_train"]: - split = f"train[:{self.num_train_examples}]" - elif split == "validation": - split = f"train[{self.num_train_examples}:]" + if split in ['train', 'eval_train']: + split = f'train[:{self.num_train_examples}]' + elif split == 'validation': + split = f'train[{self.num_train_examples}:]' ds = create_input_iter( split, ds_builder, @@ -49,8 +49,7 @@ def _build_cifar_dataset( self.padding_size, train=train, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset, - ) + repeat_final_dataset=repeat_final_dataset) return ds def _build_input_queue( @@ -61,8 +60,7 @@ def _build_input_queue( global_batch_size: int, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None, - ) -> Iterator[Dict[str, spec.Tensor]]: + num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: del num_batches return self._build_cifar_dataset(data_rng, split, @@ -86,12 +84,11 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - ) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate - model_cls = getattr(models, "ResNet18") + model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) From 1b9466ca3a637287706a704e2496fe9ba5227598 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 03:02:28 +0000 Subject: [PATCH 36/86] remove debugging statements from submission_runner.py --- .../imagenet_jax/input_pipeline.py | 1 + .../librispeech_jax/models.py | 347 ------------------ submission_runner.py | 39 +- 3 files changed, 17 insertions(+), 370 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index b1be5ac1f..35bc3635c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -399,6 +399,7 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. 
+ # TODO (kasimbeg): put on device # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 1cab9e89e..c8f49c830 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -390,23 +390,6 @@ def __call__( seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32) if use_cuda: - inputs_shape = np.shape(inputs) - h_0_shape = np.shape(h_0) - c_0_shape = np.shape(c_0) - weights_shape = np.shape(weights) - seq_lengths_np = np.shape(seq_lengths) - - n = jax.devices() - # logging.info(f"jax num devices {n}") - # logging.info(f'inputs shape {inputs_shape}') - # logging.info(f'h_0 shape {h_0_shape}') - # logging.info(f'c_0 shape {c_0_shape}') - # logging.info(f'seq_lengths shape {seq_lengths_np}') - # logging.info(f'weights_shape {weights_shape}') - # logging.info(f'input_size {input_size}') - # logging.info(f'hidden_size {self.features}') - # logging.info(f'num_layers {self.num_layers}') - y, h, c = rnn.lstm( x=inputs, h_0=h_0, c_0=c_0, weights=weights, seq_lengths=seq_lengths, input_size=input_size, @@ -443,324 +426,6 @@ def unpack_weights( ) -### Swap in regular LSTM layer for debuggin -@jax.vmap -def flip_sequences(inputs: Array, lengths: Array) -> Array: - """Flips a sequence of inputs along the time dimension. - - This function can be used to prepare inputs for the reverse direction of a - bidirectional LSTM. It solves the issue that, when naively flipping multiple - padded sequences stored in a matrix, the first elements would be padding - values for those sequences that were padded. This function keeps the padding - at the end, while flipping the rest of the elements. - - Example: - ```python - inputs = [[1, 0, 0], - [2, 3, 0] - [4, 5, 6]] - lengths = [1, 2, 3] - flip_sequences(inputs, lengths) = [[1, 0, 0], - [3, 2, 0], - [6, 5, 4]] - ``` - - Args: - inputs: An array of input IDs [batch_size, seq_length]. - lengths: The length of each sequence [batch_size]. - - Returns: - An ndarray with the flipped inputs. - """ - # Compute the indices to put the inputs in flipped order as per above example. - max_length = inputs.shape[0] - idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length - return inputs[idxs] - - -class GenericRNNSequenceEncoder(nn.Module): - """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`. - - The sequence can be encoded left-to-right (default) or right-to-left (by - calling the module with reverse=True). Regardless of encoding direction, - outputs[i, j, ...] is the representation of inputs[i, j, ...]. - - Attributes: - hidden_size: The hidden size of the RNN cell. - cell_type: The RNN cell module to use, for example, `nn.LSTMCell`. - cell_kwargs: Optional keyword arguments for the recurrent cell. - recurrent_dropout_rate: The dropout to apply across time steps. If this is - greater than zero, you must use an RNN cell that implements - `RecurrentDropoutCell` such as RecurrentDropoutOptimizedLSTMCell. - """ - hidden_size: int - cell_type: Type[nn.RNNCellBase] - cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() - recurrent_dropout_rate: float = 0.0 - - def setup(self): - self.cell = self.cell_type(features=self.hidden_size, **self.cell_kwargs) - - @functools.partial( # Repeatedly calls the below method to encode the inputs. 
- nn.transforms.scan, - variable_broadcast='params', - in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), - out_axes=1, - split_rngs={'params': False}) - def unroll_cell(self, - cell_state: StateType, - inputs: Array, - recurrent_dropout_mask: Optional[Array], - deterministic: bool): - """Unrolls a recurrent cell over an input sequence. - - Args: - cell_state: The initial cell state, shape: [batch_size, - hidden_size] (or an n-tuple thereof). - inputs: The input sequence. [batch_size, seq_len, input_dim]. - recurrent_dropout_mask: An optional recurrent dropout mask to apply in - between time steps. [batch_size, hidden_size]. - deterministic: Disables recurrent dropout when set to True. - - Returns: - The cell state after processing the complete sequence (including padding), - and a tuple with all intermediate cell states and cell outputs. - """ - # We do not directly scan the cell itself, since it only returns the output. - # This returns both the state and the output, so we can slice out the - # correct final states later. - new_cell_state, output = self.cell(cell_state, inputs) - return new_cell_state, (new_cell_state, output) - - def __call__(self, - inputs: Array, - lengths: Array, - initial_state: StateType, - reverse: bool = False, - deterministic: bool = False): - """Unrolls the RNN cell over the inputs. - - Arguments: - inputs: A batch of sequences. Shape: [batch_size, seq_len, - input_dim]. - lengths: The lengths of the input sequences. - initial_state: The initial state for the RNN cell. Shape: [batch_size, - hidden_size]. - reverse: Process the inputs in reverse order, and reverse the outputs. - This means that the outputs still correspond to the order of the inputs, - but their contexts come from the right, and not from the left. - deterministic: Disables recurrent dropout if set to True. - - Returns: - The encoded sequence of inputs, shaped [batch_size, seq_len, - hidden_size], as well as the final hidden states of the RNN cell. For an - LSTM cell the final states are a tuple (c, h), each shaped [ - batch_size, hidden_size]. - """ - if reverse: - inputs = flip_sequences(inputs, lengths) - - recurrent_dropout_mask = None - _, (cell_states, outputs) = self.unroll_cell(initial_state, - inputs, - recurrent_dropout_mask, - deterministic) - final_state = jax.tree.map( - lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states) - - if reverse: - outputs = flip_sequences(outputs, lengths) - - return outputs, final_state - - -class GenericRNN(nn.Module): - """Generic RNN class. - - This provides generic RNN functionality to encode sequences with any RNN cell. - The class provides unidirectional and bidirectional layers, and these are - stacked when asking for multiple layers. - - This class be used to create a specific RNN class such as LSTM or GRU. - - Attributes: - cell_type: An RNN cell class to use, e.g., `flax.linen.LSTMCell`. - hidden_size: The size of each recurrent cell. - num_layers: The number of stacked recurrent layers. The output of the first - layer, with optional dropout applied, feeds into the next layer. - dropout_rate: Dropout rate to be applied between LSTM layers. Only applies - when num_layers > 1. - recurrent_dropout_rate: Dropout rate to be applied on the hidden state at - each time step repeating the same dropout mask. - bidirectional: Process the sequence left-to-right and right-to-left and - concatenate the outputs from the two directions. - cell_kwargs: Optional keyword arguments to instantiate the cell with. 
- """ - cell_type: Type[nn.RNNCellBase] - hidden_size: int - num_layers: int = 1 - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. - bidirectional: bool = False - cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() - - @nn.compact - def __call__( - self, - inputs: Array, - lengths: Array, - initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: - """Processes the input sequence using the recurrent cell. - - Args: - inputs: The input sequence [batch_size, sequence_length, ...] - lengths: The lengths of each sequence in the batch. [batch_size] - initial_states: The initial states for the cells. You must provide - `num_layers` initial states (when using bidirectional, `num_layers * - 2`). - These must be ordered in the following way: (layer_0_forward, - layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, - all initial states will be initialized with zeros. - deterministic: Disables dropout between layers when set to True. - Returns: - The sequence of all outputs for the final layer, and a list of final - states for each cell and direction. Directions are alternated (first - forward, then backward, if bidirectional). For example for a bidirectional - cell this would be: layer 1 forward, layer 1 backward, layer 2 forward, - layer 2 backward, etc.. - For some cells like LSTMCell a state consists of an (c, h) tuple, while - for others cells it only contains a single vector (h,). - """ - batch_size = inputs.shape[0] - final_states = [] - num_directions = 2 if self.bidirectional else 1 - num_cells = self.num_layers * num_directions - - # Construct initial states. - if initial_states is None: # Initialize with zeros. - rng = jax.random.PRNGKey(0) - initial_states = [ - self.cell_type(self.hidden_size).initialize_carry( - rng, (batch_size, 1)) for _ in range(num_cells) - ] - if len(initial_states) != num_cells: - raise ValueError( - f'Please provide {self.num_cells} (`num_layers`, *2 if bidirectional)' - 'initial states.') - - # For each layer, apply the forward and optionally the backward RNN cell. - cell_idx = 0 - for _ in range(self.num_layers): - # Unroll an RNN cell (forward direction) for this layer. - outputs, final_state = GenericRNNSequenceEncoder( - cell_type=self.cell_type, - cell_kwargs=self.cell_kwargs, - hidden_size=self.hidden_size, - recurrent_dropout_rate=self.recurrent_dropout_rate, - name=f'{self.name}SequenceEncoder_{cell_idx}')( - inputs, - lengths, - initial_state=initial_states[cell_idx], - deterministic=deterministic) - final_states.append(final_state) - cell_idx += 1 - - # Unroll an RNN cell (backward direction) for this layer. - if self.bidirectional: - backward_outputs, backward_final_state = GenericRNNSequenceEncoder( - cell_type=self.cell_type, - cell_kwargs=self.cell_kwargs, - hidden_size=self.hidden_size, - recurrent_dropout_rate=self.recurrent_dropout_rate, - name=f'{self.name}SequenceEncoder_{cell_idx}')( - inputs, - lengths, - initial_state=initial_states[cell_idx], - reverse=True, - deterministic=deterministic) - outputs = jnp.concatenate([outputs, backward_outputs], axis=-1) - final_states.append(backward_final_state) - cell_idx += 1 - - inputs = outputs - - return outputs, final_states - - -class LSTM(nn.Module): - """LSTM. - - Attributes: - hidden_size: The size of each recurrent cell. - num_layers: The number of stacked recurrent layers. The output of the first - layer, with optional dropout applied, feeds into the next layer. 
- dropout_rate: Dropout rate to be applied between LSTM layers. Only applies - when num_layers > 1. - recurrent_dropout_rate: Dropout rate to be applied on the hidden state at - each time step repeating the same dropout mask. - bidirectional: Process the sequence left-to-right and right-to-left and - concatenate the outputs from the two directions. - cell_type: The LSTM cell class to use. Default: - `flax.linen.OptimizedLSTMCell`. If you use hidden_size of >2048, consider - using `flax.linen.LSTMCell` instead, since the optimized LSTM cell works - best for hidden sizes up to 2048. - cell_kwargs: Optional keyword arguments to instantiate the cell with. - """ - hidden_size: int - num_layers: int = 1 - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. - bidirectional: bool = False - cell_type: Any = nn.OptimizedLSTMCell - cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() - - @nn.compact - def __call__( - self, - inputs: Array, - lengths: Array, - initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: - """Processes an input sequence with an LSTM cell. - - Example usage: - ``` - inputs = np.random.normal(size=(2, 3, 4)) - lengths = np.array([1, 3]) - outputs, final_states = LSTM(hidden_size=10).apply(rngs, inputs, lengths) - ``` - - Args: - inputs: The input sequence [batch_size, sequence_length, ...] - lengths: The lengths of each sequence in the batch. [batch_size] - initial_states: The initial states for the cells. You must provide - `num_layers` initial states (when using bidirectional, `num_layers * - 2`). These must be ordered in the following way: (layer_0_forward, - layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, - all initial states will be initialized with zeros. - deterministic: Disables dropout between layers when set to True. - - Returns: - The sequence of all outputs for the final layer, and a list of final - states (h, c) for each cell and direction, ordered first by layer number - and then by direction (first forward, then backward, if bidirectional). - """ - return GenericRNN( - cell_type=self.cell_type, - hidden_size=self.hidden_size, - num_layers=self.num_layers, - dropout_rate=self.dropout_rate, - recurrent_dropout_rate=self.recurrent_dropout_rate, - bidirectional=self.bidirectional, - cell_kwargs=self.cell_kwargs, - name='LSTM')( - inputs, - lengths, - initial_states=initial_states, - deterministic=deterministic) - - class BatchRNN(nn.Module): """Implements a single deepspeech encoder layer. 
""" @@ -779,18 +444,6 @@ def __call__(self, inputs, input_paddings, train): config.batch_norm_epsilon)(inputs, input_paddings, train) - - # For regular LSTM - hidden_size = ( - config.encoder_dim // 2 if config.bidirectional else config.encoder_dim) - lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) - - # output, _ = LSTM( - # hidden_size=hidden_size, - # bidirectional=config.bidirectional, - # num_layers=1, - # )(inputs, lengths) - output = CudnnLSTM( features=config.encoder_dim // 2, bidirectional=config.bidirectional, diff --git a/submission_runner.py b/submission_runner.py index b7fdf117f..4a8493b2e 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -35,10 +35,6 @@ # New PRNG implementation for correct sharding jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) -# JAX compilation caching -jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") -jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) -jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) import torch import torch.distributed as dist @@ -392,9 +388,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if False: - # if ((train_step_end_time - train_state['last_eval_time']) >= - # workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: @@ -636,22 +631,20 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space[hi] = hyperparameters with profiler.profile('Train'): - with jax.profiler.trace("/logs/tensorboard"): - print('profiling!') - timing, metrics = train_once(workload, workload_name, - global_batch_size, - global_eval_batch_size, - data_dir, imagenet_v2_data_dir, - init_optimizer_state, - update_params, data_selection, - prepare_for_eval, - hyperparameters, - rng_seed, - rng, - profiler, - max_global_steps, - tuning_dir_name, - save_checkpoints=save_checkpoints,) + timing, metrics = train_once(workload, workload_name, + global_batch_size, + global_eval_batch_size, + data_dir, imagenet_v2_data_dir, + init_optimizer_state, + update_params, data_selection, + prepare_for_eval, + hyperparameters, + rng_seed, + rng, + profiler, + max_global_steps, + tuning_dir_name, + save_checkpoints=save_checkpoints,) all_timings[hi] = timing all_metrics[hi] = metrics logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') From 7a71cf0b09c01799153fb5f74fbc68afb0ea2040 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 03:03:24 +0000 Subject: [PATCH 37/86] pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 874297256..489c27fab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,14 +99,14 @@ wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] # Frameworks jax_core_deps = [ - "flax==0.10.4", + "flax==0.8.4", "optax==0.2.2", "chex==0.1.86", "ml_dtypes==0.4.1", "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.5.3", + "jax==0.5.3, "algoperf[jax_core_deps]", ] jax_gpu = [ From a1d0abdf0f074f49f23f17e9f5c469d5238736f6 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 03:08:59 +0000 Subject: [PATCH 38/86] clean up ogbg --- algoperf/workloads/ogbg/workload.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 14666c081..45ea778fd 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -161,7 +161,7 @@ def _eval_batch(self, spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) - jax.debug.print(str(logits)) + # jax.debug.print(str(logits)) return self._eval_metric(batch['targets'], logits, batch['weights']) def _eval_model_on_split(self, From 99caa03de1f71acc06bfaafc5ac1ed59e554bf6e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 03:12:45 +0000 Subject: [PATCH 39/86] clean up mnist workload.py --- algoperf/workloads/mnist/mnist_jax/workload.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 15fc6ff89..30f07bfd4 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -135,8 +135,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - total_metrics = { - 'accuracy': total_metrics['accuracy'].item() / num_examples, - 'loss': total_metrics['loss'].item() / num_examples - } - return total_metrics + return jax.tree.map(lambda: float(x.item() / num_examples), total_metrics) From b14174b0c6379a537c721d52147c7af06fafa7f2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 21:00:12 +0000 Subject: [PATCH 40/86] refactoring & clean up --- algoperf/data_utils.py | 5 +- algoperf/sharding_utils.py | 82 ------------------- .../workloads/cifar/cifar_jax/workload.py | 10 +-- .../criteo1tb/criteo1tb_jax/workload.py | 10 +-- .../workloads/fastmri/fastmri_jax/workload.py | 12 +-- .../imagenet_resnet/imagenet_jax/workload.py | 18 ++-- .../imagenet_vit/imagenet_jax/workload.py | 6 +- .../librispeech_jax/workload.py | 16 ++-- .../librispeech_jax/workload.py | 8 +- .../workloads/mnist/mnist_jax/workload.py | 10 +-- algoperf/workloads/ogbg/ogbg_jax/workload.py | 14 ++-- algoperf/workloads/wmt/wmt_jax/workload.py | 26 +++--- .../paper_baselines/adamw/jax/submission.py | 4 +- .../nesterov/jax/submission.py | 4 +- .../jax_submission_base.py | 8 +- 15 files changed, 75 insertions(+), 158 deletions(-) delete mode 100644 algoperf/sharding_utils.py diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index abd3f51b3..72c0043e5 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec @@ -51,7 +51,6 @@ def shard_and_maybe_pad_np( weights = batch.get('weights') # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights - naive_sharding_spec = sharding_utils.get_naive_sharding_spec() def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. 
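Illustrative aside (not part of the patch series): the data_utils hunks here place each prepared batch with jax.device_put under a batch-dimension NamedSharding instead of a pmap-style leading device axis. A minimal sketch of that pattern, assuming a one-dimensional "batch" mesh like the one jax_sharding_utils constructs; the names below are hypothetical, not the repo's API:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # One-dimensional device mesh with a single "batch" axis.
    mesh = Mesh(jax.devices(), ("batch",))
    batch_sharding = NamedSharding(mesh, P("batch"))  # split axis 0 across devices
    replicated = NamedSharding(mesh, P())             # full copy on every device

    # Place a host batch so each device owns a contiguous slice of axis 0.
    # Assumes the global batch size divides evenly by jax.device_count().
    x = jnp.ones((128, 32))
    x_sharded = jax.device_put(x, batch_sharding)

    # A jitted train/eval step can then request the same layouts explicitly:
    # jax.jit(step_fn, in_shardings=(replicated, batch_sharding), out_shardings=replicated)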
@@ -62,7 +61,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return jax.device_put(x, naive_sharding_spec) + return jax.device_put(x, jax.sharding_utils.get_batch_dim_sharding()) return jax.tree.map(_prepare, batch) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py deleted file mode 100644 index fc6b38d4a..000000000 --- a/algoperf/sharding_utils.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Utilities for dealing with sharding in JAX.""" - -import jax -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec - - -def get_mesh() -> jax.sharding.Mesh: - """Creates a mesh from all available GPUs. - Here, we simply create a one-dimensional mesh.""" - return jax.sharding.Mesh(jax.devices(), ("batch",)) - - -def get_replicated_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) - - -def shard_replicated(x, mesh=None): - """Shards a tensor across all devices.""" - if mesh is None: - mesh = get_mesh() - return jax.tree.map( - lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) - - -def get_naive_sharding_spec(mesh=None): - """Returns a sharding spec that shards data along the first axis.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec("batch")) - - -def get_naive_sharding(x, mesh=None): - """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" - if mesh is None: - mesh = get_mesh() - grid_size = mesh.shape["batch"] - if len(x.shape) > 0 and x.shape[0] % grid_size == 0: - return NamedSharding(mesh, PartitionSpec("batch")) - else: - return NamedSharding(mesh, PartitionSpec()) - - -def shard_params(params, mesh=None): - """Shards a parameter tree across all devices - with naive sharding (see get_naive_sharding).""" - if mesh is None: - mesh = get_mesh() - return jax.tree.map(lambda x: jax.device_put(x, get_naive_sharding(x)), - params) - - -def shard_naive(x, mesh=None): - return shard_params(x, mesh) - - -def get_naive_sharding_tree(input_tree, mesh=None): - if mesh is None: - mesh = get_mesh() - return jax.tree.map(lambda x: get_naive_sharding(x, mesh), input_tree) - - -def get_sharding_tree(params, mesh=None): - """Returns a sharding tree for a parameter tree.""" - return jax.tree.map(lambda x: get_naive_sharding(x, mesh), params) - - -def get_empty_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) - - -def disp_shard_info(x: jax.Array): - """Displays shard info of a jax array.""" - for shard in x.addressable_shards: - print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" - f" {shard.replica_id}.\n") diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 911812455..5248eeec4 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -12,7 +12,7 @@ import tensorflow_datasets as tfds from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -186,10 +186,10 @@ def _eval_model( @functools.partial( jax.jit, in_shardings=( - 
sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - sharding_utils.get_replicated_sharding(), # model_state - sharding_utils.get_naive_sharding_spec(), # rng + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_sharding(), # batch + jax_sharding_utils.get_replicated_sharding(), # model_state + jax_sharding_utils.get_batch_sharding(), # rng ), ) def _per_device_eval_model( diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 30ecd2b00..f02860feb 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -11,7 +11,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -106,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return sharding_utils.shard_replicated(initial_params), None + return jax_sharding_utils.shard(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -132,11 +132,11 @@ def model_fn( @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), + jax_sharding_utils.get_replicated_sharding(), + jax_sharding_utils.get_batch_sharding(), ), static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding()) + out_shardings=jax_sharding_utils.get_replicated_sharding()) def _eval_batch_jitted(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> spec.Tensor: diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 2709ef1e6..00a4ca708 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -10,7 +10,7 @@ from algoperf import param_utils from algoperf import spec -from algoperf import sharding_utils +from algoperf import jax_sharding_utils import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim @@ -40,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = sharding_utils.shard_replicated(params) + params = jax_sharding_utils.shard(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -96,11 +96,11 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), - sharding_utils.get_replicated_sharding()), + in_shardings=(jax_sharding_utils.get_replicated_sharding(), + jax_sharding_utils.get_batch_sharding(), + jax_sharding_utils.get_replicated_sharding()), static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding()) + out_shardings=jax_sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], diff --git 
a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 87b9b82bd..eb2d5809c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -20,7 +20,7 @@ from algoperf import param_utils from algoperf import random_utils as prng -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -103,14 +103,14 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - mesh = sharding_utils.get_mesh() + mesh = jax_sharding_utils.get_mesh() params = jax.tree_map( lambda x: jax.device_put(x, - sharding_utils.get_replicated_sharding(mesh)), + jax_sharding_utils.get_replicated_sharding(mesh)), params) model_state = jax.tree_map( lambda x: jax.device_put(x, - sharding_utils.get_replicated_sharding(mesh)), + jax_sharding_utils.get_replicated_sharding(mesh)), model_state) return params, model_state @@ -120,13 +120,13 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - sharding_utils.get_replicated_sharding(), # model_state - sharding_utils.get_replicated_sharding(), # rng + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_sharding(), # batch + jax_sharding_utils.get_replicated_sharding(), # model_state + jax_sharding_utils.get_replicated_sharding(), # rng ), static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding()) + out_shardings=jax_sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 1a2ce6342..1cc705aa9 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload @@ -46,8 +46,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = sharding_utils.shard_replicated(params) - model_state = sharding_utils.shard_replicated(model_state) + params = jax_sharding_utils.shard(params) + model_state = jax_sharding_utils.shard(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index d1bf29ba4..19ad06eca 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -12,7 +12,7 @@ from algoperf import data_utils from algoperf 
import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics from algoperf.workloads.librispeech_conformer import workload @@ -94,8 +94,8 @@ def init_model_fn( self._param_types = param_utils.jax_param_types(self._param_shapes) # Add sharding - params = sharding_utils.shard_replicated(params) - model_state = sharding_utils.shard_replicated(model_state) + params = jax_sharding_utils.shard(params) + model_state = jax_sharding_utils.shard(model_state) return params, model_state @@ -310,12 +310,12 @@ def greedy_decode( @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - sharding_utils.get_replicated_sharding(), # model_state - sharding_utils.get_replicated_sharding(), # rng + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_sharding(), # batch + jax_sharding_utils.get_replicated_sharding(), # model_state + jax_sharding_utils.get_replicated_sharding(), # rng ), - out_shardings=sharding_utils.get_naive_sharding_spec(), + out_shardings=jax_sharding_utils.get_batch_sharding(), static_argnums=(0,)) def _eval_step( self, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 5c34cbd74..d636c038c 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -10,7 +10,7 @@ from algoperf import param_utils from algoperf import spec -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -55,8 +55,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = sharding_utils.shard_replicated(model_state) - params = sharding_utils.shard_replicated(params) + model_state = jax_sharding_utils.shard(model_state) + params = jax_sharding_utils.shard(params) return params, model_state def model_fn_ref( @@ -108,7 +108,7 @@ def model_fn( use_running_average_bn=use_running_average_bn) model_fn_sharded = shard_map(model_fn_partial, - sharding_utils.get_mesh(), + jax_sharding_utils.get_mesh(), in_specs=(None, P('batch'), None), out_specs=(P('batch'), None), ) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 30f07bfd4..cbf80fe52 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,7 @@ import optax from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.mnist.workload import BaseMnistWorkload @@ -103,10 +103,10 @@ def loss_fn( @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - sharding_utils.get_replicated_sharding(), # model_state - sharding_utils.get_naive_sharding_spec(), # rng + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_sharding(), # batch + 
jax_sharding_utils.get_replicated_sharding(), # model_state + jax_sharding_utils.get_batch_sharding(), # rng ), static_argnums=(0,)) def _eval_model( diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index a7b8084fd..b8795ed38 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,7 +8,7 @@ import jraph import optax -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.ogbg import metrics @@ -46,7 +46,7 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = sharding_utils.shard_replicated(params) + params = jax_sharding_utils.shard(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -111,12 +111,12 @@ def _eval_metric(self, labels, logits, masks): @functools.partial( jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), - sharding_utils.get_replicated_sharding(), - sharding_utils.get_replicated_sharding()), + in_shardings=(jax_sharding_utils.get_replicated_sharding(), + jax_sharding_utils.get_batch_sharding(), + jax_sharding_utils.get_replicated_sharding(), + jax_sharding_utils.get_replicated_sharding()), static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding(), + out_shardings=jax_sharding_utils.get_replicated_sharding(), ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index dce08a677..30d306fbb 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -14,7 +14,7 @@ import optax from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode @@ -72,8 +72,8 @@ def compute_weighted_cross_entropy( @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_sharding(), # batch ), static_argnums=(0,), # self ) @@ -100,7 +100,7 @@ def eval_step(self, @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_naive_sharding_spec(), # inputs + jax_sharding_utils.get_batch_sharding(), # inputs ), static_argnums=( 0, @@ -112,9 +112,9 @@ def initialize_cache(self, """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - dummy_inputs = sharding_utils.shard_naive( + dummy_inputs = jax_sharding_utils.shard_naive( jnp.ones(inputs.shape, jnp.float32)) - dummy_targets = sharding_utils.shard_naive( + dummy_targets = jax_sharding_utils.shard_naive( jnp.ones(target_shape, jnp.float32)) initial_variables = models.Transformer(config).init( jax.random.PRNGKey(0), dummy_inputs, dummy_targets) @@ -191,14 +191,14 @@ def translate_and_calculate_bleu(self, for _ in range(num_batches): pred_batch = next(ds_iter) cache = 
self.initialize_cache(pred_batch['inputs']) - cache = sharding_utils.shard_naive(cache) + cache = jax_sharding_utils.shard_naive(cache) if jitted_predict_step is None: jitted_predict_step = jax.jit( self.predict_step, in_shardings=( - sharding_utils.get_naive_sharding_spec(), # inputs - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_tree(cache), # cache + jax_sharding_utils.get_batch_sharding(), # inputs + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_naive_sharding_tree(cache), # cache ), static_argnums=( 3, # eos_id @@ -260,8 +260,8 @@ def init_model_fn( params_rng, dropout_rng = jax.random.split(rng) inputs = jnp.ones(input_shape, jnp.float32) targets = jnp.ones(target_shape, jnp.float32) - sharded_inputs = sharding_utils.shard_naive(inputs) - sharded_targets = sharding_utils.shard_naive(targets) + sharded_inputs = jax_sharding_utils.shard_naive(inputs) + sharded_targets = jax_sharding_utils.shard_naive(targets) initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, @@ -271,7 +271,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = sharding_utils.shard_replicated(initial_params) + params = jax_sharding_utils.shard(initial_params) return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 3423f6e85..16105c802 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -11,7 +11,7 @@ from jax.sharding import PartitionSpec as P import optax -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -136,7 +136,7 @@ def update_params( grad_clip = None # Set up mesh and sharding - mesh = sharding_utils.get_mesh() + mesh = jax_sharding_utils.get_mesh() replicated = NamedSharding(mesh, P()) # No partitioning sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..1e42d7b94 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -11,7 +11,7 @@ from jax.sharding import PartitionSpec as P import optax -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -174,7 +174,7 @@ def update_params( else: grad_clip = None - mesh = sharding_utils.get_mesh() + mesh = jax_sharding_utils.get_mesh() # Create shardings for each argument replicated = NamedSharding(mesh, P()) # No partitioning sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 1cfa5deca..20f015821 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algoperf 
import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -91,10 +91,10 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - mesh = sharding_utils.get_mesh() + mesh = jax_sharding_utils.get_mesh() # Create shardings for each argument - replicated = sharding_utils.get_replicated_sharding(mesh) # No partitioning - sharded = sharding_utils.get_naive_sharding_spec( + replicated = jax_sharding_utils.get_replicated_sharding(mesh) # No partitioning + sharded = jax_sharding_utils.get_batch_sharding( mesh) # Partition along batch dimension # Create the sharding rules for each argument From a3a9b9fadbb994e2d05f14d7ecbb5477486fbc11 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 22:16:05 +0000 Subject: [PATCH 41/86] simplify changes in cifar jax --- .../cifar/cifar_jax/input_pipeline.py | 6 +- .../workloads/cifar/cifar_jax/workload.py | 104 ++++++++---------- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 459ed9266..24c3d5603 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,6 +8,7 @@ import functools from typing import Dict, Iterator, Tuple +from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds @@ -170,7 +171,6 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - # FIXME(rka97): Figure out how to do prefetching+sharding. - # TODO (kasimbeg) - # it = jax_utils.prefetch_to_device(it, 2) + + it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 5248eeec4..41ae2c60e 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -28,10 +28,9 @@ def _build_cifar_dataset( data_dir: str, batch_size: int, cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None ) -> Iterator[Dict[str, spec.Tensor]]: - data_dir = data_dir + "/cifar10" - ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) + ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) train = split == 'train' assert self.num_train_examples + self.num_validation_examples == 50000 if split in ['train', 'eval_train']: @@ -92,15 +91,17 @@ def init_model_fn( model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({"params": rng}, + variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) + model_state = jax_sharding_utils.replicate(params) + params = jax_sharding_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: - return param_key == "Dense_0" + return param_key == 'Dense_0' def model_fn( self, @@ -114,11 +115,11 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng - variables = {"params": params, **model_state} + variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = 
self._model.apply( variables, - augmented_and_preprocessed_input_batch["inputs"], + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn) @@ -126,7 +127,7 @@ def model_fn( else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch["inputs"], + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn) @@ -139,15 +140,13 @@ def loss_fn( label_batch: spec.Tensor, # Dense or one-hot labels. logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0, - ) -> Dict[str, spec.Tensor]: # differentiable + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). - Return {'summed': scalar summed loss, - 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). - """ + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -160,66 +159,51 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - "summed": summed_loss, - "n_valid_examples": n_valid_examples, - "per_example": per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)["summed"] + summed_loss = self.loss_fn(labels, logits, weights)['summed'] # Number of correct predictions. 
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - return jnp.array(summed_loss), jnp.array(accuracy) + metrics = { + 'loss': summed_loss, + 'accuracy': accuracy, + } + return metrics + @functools.partial( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_sharding(), # batch + jax_sharding_utils.get_replicated_sharding(), # model_state + jax_sharding _utils.get_batch_sharding(), # rng + ), + ) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - - @functools.partial( - jax.jit, - in_shardings=( - jax_sharding_utils.get_replicated_sharding(), # params - jax_sharding_utils.get_batch_sharding(), # batch - jax_sharding_utils.get_replicated_sharding(), # model_state - jax_sharding_utils.get_batch_sharding(), # rng - ), - ) - def _per_device_eval_model( - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - ) - weights = batch.get("weights") - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch["targets"], weights) - - losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) - metrics = { - "loss": - jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, - "accuracy": - (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies - ), - } - return metrics + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + weights = batch.get('weights') + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch['targets'], weights) def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, From 0a340a2f9f82629d9d5971fc8a605e2e3cb5ea38 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 22:29:17 +0000 Subject: [PATCH 42/86] small fix --- algoperf/workloads/cifar/cifar_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 41ae2c60e..ffd313d72 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -183,7 +183,7 @@ def _compute_metrics(self, jax_sharding_utils.get_replicated_sharding(), # params jax_sharding_utils.get_batch_sharding(), # batch jax_sharding_utils.get_replicated_sharding(), # model_state - jax_sharding _utils.get_batch_sharding(), # rng + jax_sharding_utils.get_batch_sharding(), # rng ), ) def _eval_model( From 60c1cce842a1e82cf81cd17b377abbc0ff3883c2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 22:32:19 +0000 Subject: [PATCH 43/86] rename sharding utils --- algoperf/jax_sharding_utils.py | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 algoperf/jax_sharding_utils.py diff --git a/algoperf/jax_sharding_utils.py b/algoperf/jax_sharding_utils.py new file mode 100644 index 000000000..4fbbe0cc6 --- /dev/null +++ 
b/algoperf/jax_sharding_utils.py @@ -0,0 +1,37 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import NamedSharding, PartitionSpec as P + + +def get_replicate_sharding(): + """Returns a sharding spec that replicates data across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return NamedSharding(mesh, P()) + + +def get_batch_dim_sharding(): + """Returns a sharding spec that shards data along the first axis.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return NamedSharding(mesh, P('batch')) + + +def shard_along_batch_dim(x): + """Shards a tensor across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return jax.tree.map( + lambda x: jax.device_put(x, NamedSharding(mesh, P('batch')))) + + +def replicate(x): + """Replicates tensor across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return jax.tree.map( + lambda x: jax.device_put(x, NamedSharding(mesh, P()))) + + +def disp_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") \ No newline at end of file From 1edb72434347eecd80e125225eac5e1f444b9303 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 22:47:34 +0000 Subject: [PATCH 44/86] fix sharding rename --- algoperf/workloads/cifar/cifar_jax/workload.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ffd313d72..3e4da5684 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -180,10 +180,10 @@ def _compute_metrics(self, @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_replicated_sharding(), # params - jax_sharding_utils.get_batch_sharding(), # batch - jax_sharding_utils.get_replicated_sharding(), # model_state - jax_sharding_utils.get_batch_sharding(), # rng + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng ), ) def _eval_model( From 49864fb430f6ef3ea6b667f84573d45ee804847e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 3 Apr 2025 22:55:14 +0000 Subject: [PATCH 45/86] refactoring --- .../criteo1tb/criteo1tb_jax/workload.py | 8 ++++---- .../imagenet_resnet/imagenet_jax/workload.py | 15 +++++++-------- .../librispeech_jax/workload.py | 10 +++++----- .../librispeech_jax/workload.py | 2 +- algoperf/workloads/mnist/mnist_jax/workload.py | 8 ++++---- algoperf/workloads/ogbg/ogbg_jax/workload.py | 10 +++++----- algoperf/workloads/wmt/wmt_jax/workload.py | 16 ++++++++-------- 7 files changed, 34 insertions(+), 35 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index f02860feb..9c5bc7690 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -106,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_sharding_utils.shard(initial_params), None + return 
jax_sharding_utils.shard_along_batch_dim(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -132,11 +132,11 @@ def model_fn( @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_replicated_sharding(), - jax_sharding_utils.get_batch_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), ), static_argnums=(0,), - out_shardings=jax_sharding_utils.get_replicated_sharding()) + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_batch_jitted(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> spec.Tensor: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index eb2d5809c..0d6e8912d 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -103,14 +103,13 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - mesh = jax_sharding_utils.get_mesh() params = jax.tree_map( lambda x: jax.device_put(x, - jax_sharding_utils.get_replicated_sharding(mesh)), + jax_sharding_utils.get_replicate_sharding()), params) model_state = jax.tree_map( lambda x: jax.device_put(x, - jax_sharding_utils.get_replicated_sharding(mesh)), + jax_sharding_utils.get_replicate_sharding()), model_state) return params, model_state @@ -120,13 +119,13 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_replicated_sharding(), # params - jax_sharding_utils.get_batch_sharding(), # batch - jax_sharding_utils.get_replicated_sharding(), # model_state - jax_sharding_utils.get_replicated_sharding(), # rng + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_replicate_sharding(), # rng ), static_argnums=(0,), - out_shardings=jax_sharding_utils.get_replicated_sharding()) + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 19ad06eca..dbaa58ab6 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -310,12 +310,12 @@ def greedy_decode( @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_replicated_sharding(), # params - jax_sharding_utils.get_batch_sharding(), # batch - jax_sharding_utils.get_replicated_sharding(), # model_state - jax_sharding_utils.get_replicated_sharding(), # rng + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_replicate_sharding(), # rng ), - out_shardings=jax_sharding_utils.get_batch_sharding(), + out_shardings=jax_sharding_utils.get_batch_dim_sharding(), static_argnums=(0,)) def _eval_step( self, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py 
b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d636c038c..e324de76e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -108,7 +108,7 @@ def model_fn( use_running_average_bn=use_running_average_bn) model_fn_sharded = shard_map(model_fn_partial, - jax_sharding_utils.get_mesh(), + jax.sharding.Mesh(jax.devices(), ('batch')), in_specs=(None, P('batch'), None), out_specs=(P('batch'), None), ) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index cbf80fe52..771fd8234 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -103,10 +103,10 @@ def loss_fn( @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_replicated_sharding(), # params - jax_sharding_utils.get_batch_sharding(), # batch - jax_sharding_utils.get_replicated_sharding(), # model_state - jax_sharding_utils.get_batch_sharding(), # rng + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng ), static_argnums=(0,)) def _eval_model( diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index b8795ed38..e25ac3b4e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -111,12 +111,12 @@ def _eval_metric(self, labels, logits, masks): @functools.partial( jax.jit, - in_shardings=(jax_sharding_utils.get_replicated_sharding(), - jax_sharding_utils.get_batch_sharding(), - jax_sharding_utils.get_replicated_sharding(), - jax_sharding_utils.get_replicated_sharding()), + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_replicate_sharding()), static_argnums=(0,), - out_shardings=jax_sharding_utils.get_replicated_sharding(), + out_shardings=jax_sharding_utils.get_replicate_sharding(), ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 30d306fbb..96dad79c1 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -100,7 +100,7 @@ def eval_step(self, @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_batch_sharding(), # inputs + jax_sharding_utils.get_batch_dim_sharding(), # inputs ), static_argnums=( 0, @@ -112,9 +112,9 @@ def initialize_cache(self, """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - dummy_inputs = jax_sharding_utils.shard_naive( + dummy_inputs = jax_sharding_utils.shard_along_batch_dim( jnp.ones(inputs.shape, jnp.float32)) - dummy_targets = jax_sharding_utils.shard_naive( + dummy_targets = jax_sharding_utils.shard_along_batch_dim( jnp.ones(target_shape, jnp.float32)) initial_variables = models.Transformer(config).init( jax.random.PRNGKey(0), dummy_inputs, dummy_targets) @@ -196,8 +196,8 @@ def translate_and_calculate_bleu(self, jitted_predict_step = jax.jit( 
self.predict_step, in_shardings=( - jax_sharding_utils.get_batch_sharding(), # inputs - jax_sharding_utils.get_replicated_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # inputs + jax_sharding_utils.get_replicate_sharding(), # params jax_sharding_utils.get_naive_sharding_tree(cache), # cache ), static_argnums=( @@ -260,8 +260,8 @@ def init_model_fn( params_rng, dropout_rng = jax.random.split(rng) inputs = jnp.ones(input_shape, jnp.float32) targets = jnp.ones(target_shape, jnp.float32) - sharded_inputs = jax_sharding_utils.shard_naive(inputs) - sharded_targets = jax_sharding_utils.shard_naive(targets) + sharded_inputs = jax_sharding_utils.shard_along_batch_dim(inputs) + sharded_targets = jax_sharding_utils.shard_along_batch_dim(targets) initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, @@ -271,7 +271,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_sharding_utils.shard(initial_params) + params = jax_sharding_utils.shard_along_batch_dim(initial_params) return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From 7820ac68825ef9fc14dd177a897344275928690b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 4 Apr 2025 00:05:14 +0000 Subject: [PATCH 46/86] modifications to cifar --- algoperf/data_utils.py | 2 +- algoperf/jax_sharding_utils.py | 2 +- .../cifar/cifar_jax/input_pipeline.py | 1 - .../workloads/cifar/cifar_jax/workload.py | 64 +++++++++++-------- .../paper_baselines/adamw/jax/submission.py | 4 +- 5 files changed, 44 insertions(+), 29 deletions(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 72c0043e5..9a7b91b15 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -61,7 +61,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return jax.device_put(x, jax.sharding_utils.get_batch_dim_sharding()) + return jax.device_put(x, jax_sharding_utils.get_batch_dim_sharding()) return jax.tree.map(_prepare, batch) diff --git a/algoperf/jax_sharding_utils.py b/algoperf/jax_sharding_utils.py index 4fbbe0cc6..e318f12ac 100644 --- a/algoperf/jax_sharding_utils.py +++ b/algoperf/jax_sharding_utils.py @@ -27,7 +27,7 @@ def replicate(x): """Replicates tensor across all devices.""" mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) return jax.tree.map( - lambda x: jax.device_put(x, NamedSharding(mesh, P()))) + lambda x: jax.device_put(x, NamedSharding(mesh, P())), x) def disp_shard_info(x: jax.Array): diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 24c3d5603..8eec88f28 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -172,5 +172,4 @@ def create_input_iter( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 3e4da5684..48cee94b4 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -31,6 +31,7 @@ def _build_cifar_dataset( repeat_final_dataset: Optional[bool] = None ) -> Iterator[Dict[str, spec.Tensor]]: ds_builder = 
tfds.builder('cifar10:3.0.2', data_dir=data_dir) + ds_builder.download_and_prepare() train = split == 'train' assert self.num_train_examples + self.num_validation_examples == 50000 if split in ['train', 'eval_train']: @@ -177,33 +178,46 @@ def _compute_metrics(self, } return metrics - @functools.partial( - jax.jit, - in_shardings=( - jax_sharding_utils.get_replicate_sharding(), # params - jax_sharding_utils.get_batch_dim_sharding(), # batch - jax_sharding_utils.get_replicate_sharding(), # model_state - jax_sharding_utils.get_batch_dim_sharding(), # rng - ), - ) def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng + ), + ) + def _eval_model_jitted( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + """Return the mean accuracy and loss as a dict.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + weights = batch.get('weights') + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch['targets'], weights) + + metrics = _eval_model_jitted(params, + batch, + model_state, + rng) + return jax.tree.map(lambda x: x.item(), metrics) def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 16105c802..9e85d3d84 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -136,7 +136,7 @@ def update_params( grad_clip = None # Set up mesh and sharding - mesh = jax_sharding_utils.get_mesh() + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) replicated = NamedSharding(mesh, P()) # No partitioning sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension @@ -228,6 +228,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 32 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 0a2043c7251d0e4911273a4734e31b9046e96946 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 5 Apr 2025 01:13:56 +0000 Subject: [PATCH 47/86] fix --- algoperf/workloads/mnist/mnist_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 
771fd8234..ad2d7fc8a 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -135,4 +135,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda: float(x.item() / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x.item() / num_examples), total_metrics) From 95037bf550e5e1644ccd6e7da7a541a9a5413976 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 5 Apr 2025 06:13:53 +0000 Subject: [PATCH 48/86] clean up and small fixes --- algoperf/jax_sharding_utils.py | 4 ++-- algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- algoperf/workloads/fastmri/fastmri_jax/workload.py | 10 +++++----- .../workloads/imagenet_vit/imagenet_jax/workload.py | 4 ++-- .../librispeech_conformer/librispeech_jax/workload.py | 6 +++--- algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +- algoperf/workloads/wmt/wmt_jax/workload.py | 7 +++---- 7 files changed, 17 insertions(+), 18 deletions(-) diff --git a/algoperf/jax_sharding_utils.py b/algoperf/jax_sharding_utils.py index e318f12ac..6c90c5cd7 100644 --- a/algoperf/jax_sharding_utils.py +++ b/algoperf/jax_sharding_utils.py @@ -20,7 +20,7 @@ def shard_along_batch_dim(x): """Shards a tensor across all devices.""" mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) return jax.tree.map( - lambda x: jax.device_put(x, NamedSharding(mesh, P('batch')))) + lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x) def replicate(x): @@ -30,7 +30,7 @@ def replicate(x): lambda x: jax.device_put(x, NamedSharding(mesh, P())), x) -def disp_shard_info(x: jax.Array): +def display_shard_info(x: jax.Array): """Displays shard info of a jax array.""" for shard in x.addressable_shards: print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 9c5bc7690..723326120 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -106,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_sharding_utils.shard_along_batch_dim(initial_params), None + return jax_sharding_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 00a4ca708..1349cef64 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -40,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_sharding_utils.shard(params) + params = jax_sharding_utils.replicate(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -96,11 +96,11 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=(jax_sharding_utils.get_replicated_sharding(), - jax_sharding_utils.get_batch_sharding(), - jax_sharding_utils.get_replicated_sharding()), + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + 
jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding()), static_argnums=(0,), - out_shardings=jax_sharding_utils.get_replicated_sharding()) + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 1cc705aa9..c4d823319 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -46,8 +46,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_sharding_utils.shard(params) - model_state = jax_sharding_utils.shard(model_state) + params = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index dbaa58ab6..758344a23 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -94,8 +94,8 @@ def init_model_fn( self._param_types = param_utils.jax_param_types(self._param_shapes) # Add sharding - params = jax_sharding_utils.shard(params) - model_state = jax_sharding_utils.shard(model_state) + params = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) return params, model_state @@ -342,7 +342,7 @@ def _eval_step( decoded, 'decoded_paddings': decoded_paddings, - 'targets': + 'targets': targets, 'target_paddings': target_paddings, diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e25ac3b4e..abfd70504 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -46,7 +46,7 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_sharding_utils.shard(params) + params = jax_sharding_utils.replicate(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 96dad79c1..36f5b8606 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -72,8 +72,8 @@ def compute_weighted_cross_entropy( @functools.partial( jax.jit, in_shardings=( - jax_sharding_utils.get_replicated_sharding(), # params - jax_sharding_utils.get_batch_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch ), static_argnums=(0,), # self ) @@ -191,14 +191,13 @@ def translate_and_calculate_bleu(self, for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - cache = jax_sharding_utils.shard_naive(cache) if jitted_predict_step is None: jitted_predict_step = jax.jit( self.predict_step, in_shardings=( jax_sharding_utils.get_batch_dim_sharding(), # inputs 
jax_sharding_utils.get_replicate_sharding(), # params - jax_sharding_utils.get_naive_sharding_tree(cache), # cache + jax_sharding_utils.get_replicate_sharding(), # cache ), static_argnums=( 3, # eos_id From e79c7616776a9bbcd8454ddc2d0e6b8fa314a5db Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 5 Apr 2025 06:14:17 +0000 Subject: [PATCH 49/86] add test for sharding invariance --- tests/test_jax_sharding_invariance.py | 90 +++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/test_jax_sharding_invariance.py diff --git a/tests/test_jax_sharding_invariance.py b/tests/test_jax_sharding_invariance.py new file mode 100644 index 000000000..823ef0843 --- /dev/null +++ b/tests/test_jax_sharding_invariance.py @@ -0,0 +1,90 @@ +"""Tests for sharding consistency in JAX workloads. + +Specifically this will test the model_init functions, and input_pipeline. +""" +import copy +import os +import sys + +from absl import flags +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized + +from algoperf.profiler import PassThroughProfiler +import submission_runner +from algoperf.workloads.workloads import import_workload +from algoperf.workloads.workloads import BASE_WORKLOADS_DIR +from algoperf.workloads.workloads import WORKLOADS + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +# (see https://github.com/google/model_search/pull/8). +FLAGS(sys.argv) + +FRAMEWORK = 'jax' # Can extend to pytorch later + + +test_case = dict(testcase_name='test_ogbg', + workload='ogbg') + + +class SubmissionRunnerTest(parameterized.TestCase): + """Tests for reference submissions.""" + + + @parameterized.named_parameters(test_case) + def test_invariance(self, workload_name): + workload_name = 'ogbg' + dataset_dir = f'/data/{workload_name}' + workload_metadata = copy.deepcopy(WORKLOADS[workload_name]) + workload_metadata['workload_path'] = os.path.join(BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + FRAMEWORK, + 'workload.py') + workload = import_workload(workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}) + + rng = jax.random.PRNGKey(0) + initial_params, model_state = workload.init_model_fn(rng) + data_iter = workload._build_input_queue(rng, 'train', dataset_dir, 32) + batch = next(data_iter) + inputs = batch['inputs'] + + def forward_pass(params, + batch, + model_state, + rng,): + logits, _ = workload.model_fn(initial_params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + return logits + + forward_pass_jitted = jax.jit(forward_pass, + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_replicate_sharding(), + ), + out_shardings=jax_sharding_utils.get_batch_dim_sharding()) + + logits = forward_pass(initial_params, + batch, + model_state, + rng,) + + logits_jitted = forward_pass_jitted(initial_params, + batch, + model_state, + rng,) + + jax.debug.visualize_array_sharding(logits_jitted) + + equal = jnp.allclose(logits, logits_jitted, atol=1e-6) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file From 110e79296d2c3c0349fdbcb86a23d9aac6561297 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 8 Apr 2025 17:45:44 +0000 Subject: [PATCH 50/86] fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pyproject.toml b/pyproject.toml index 489c27fab..45be34498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.5.3, + "jax==0.5.3", "algoperf[jax_core_deps]", ] jax_gpu = [ From 9c91c657101e461fd21e2e4c5729fe57e21df85e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Mon, 14 Apr 2025 12:36:53 -0700 Subject: [PATCH 51/86] Update pyproject.toml syntax fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 489c27fab..45be34498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.5.3, + "jax==0.5.3", "algoperf[jax_core_deps]", ] jax_gpu = [ From 21bb997cc700b6ccae0f177b8f5e4d509998391b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Mon, 14 Apr 2025 12:49:44 -0700 Subject: [PATCH 52/86] Update workload.py fix --- .../librispeech_deepspeech/librispeech_jax/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index e324de76e..2c45e9e4d 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,8 +55,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.shard(model_state) - params = jax_sharding_utils.shard(params) + model_state = jax_sharding_utils.shard_along_batch_dim(model_state) + params = jax_sharding_utils.shard_along_batch_dim(params) return params, model_state def model_fn_ref( From eb5691942919b3793dc1e6b31b497652b3caf9f1 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Mon, 14 Apr 2025 12:51:57 -0700 Subject: [PATCH 53/86] Update workload.py --- .../librispeech_deepspeech/librispeech_jax/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2c45e9e4d..8a54c6455 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,8 +55,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.shard_along_batch_dim(model_state) - params = jax_sharding_utils.shard_along_batch_dim(params) + model_state = jax_sharding_utils.replicate(model_state) + params = jax_sharding_utils.replicate(params) return params, model_state def model_fn_ref( From 1277cc2345e18ad14d3c7dbe525c63b77d1608ba Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 19 May 2025 16:30:16 +0000 Subject: [PATCH 54/86] upgrade jax --- README.md | 2 +- pyproject.toml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index f03890cbb..5823c173d 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ Both options are described in detail in the [**Getting Started**](/docs/GETTING_ ```bash pip3 install -e '.[pytorch_cpu]' -pip3 install -e '.[jax_gpu]' -f 
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' +pip3 install -e '.[jax_gpu]' pip3 install -e '.[full]' ``` diff --git a/pyproject.toml b/pyproject.toml index 45be34498..0e36837de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "docker==7.1.0", "numpy>=2.0.2", "pandas>=2.0.1", - "tensorflow==2.18.0", + "tensorflow==2.19.0", "tensorflow-datasets==4.9.7", "tensorflow-probability==0.20.0", "tensorflow-addons==0.20.0", @@ -92,25 +92,25 @@ fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"] ogbg = ["jraph==0.0.6.dev0", "scikit-learn==1.5.2"] librispeech_conformer = [ "sentencepiece==0.2.0", - "tensorflow-text==2.18.0", + "tensorflow-text==2.19.0", "pydub==0.25.1", ] -wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] # Frameworks jax_core_deps = [ "flax==0.8.4", "optax==0.2.2", "chex==0.1.86", - "ml_dtypes==0.4.1", + "ml_dtypes==0.5.1", "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.5.3", + "jax==0.6.0", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax[cuda12]==0.5.3", + "jax[cuda12]==0.6.0", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From def4ac52de8086d7df9529dce330551a7ae542a4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 19 May 2025 17:25:42 +0000 Subject: [PATCH 55/86] update dockerfile --- docker/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..623daac2f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" @@ -71,7 +71,7 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ + && pip install -e '.[jax_gpu]' \ && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ @@ -81,7 +81,7 @@ RUN if [ "$framework" = "jax" ] ; then \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ + && pip install -e '.[jax_gpu]' \ && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." 
>&2 \ From 450cbeecd4b4758438c4b63c1ec21bb0d7f977f7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 19 May 2025 18:34:56 +0000 Subject: [PATCH 56/86] remove extra installs --- docker/Dockerfile | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 623daac2f..db28b77c1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,15 +16,7 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg # Install prerequisites RUN apt-get update && apt-get install -y \ wget \ - build-essential \ - zlib1g-dev \ - libncurses5-dev \ - libssl-dev \ - libreadline-dev \ - libffi-dev \ curl \ - libbz2-dev \ - liblzma-dev \ vim # Download and install Python 3.11 From 89718e78a8c93ec2b372b06ea7c84d2c3f6eda61 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 00:45:55 +0000 Subject: [PATCH 57/86] update jax version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0e36837de..6ed38b412 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,7 @@ jax_cpu = [ "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax[cuda12]==0.6.0", + "jax[cuda12-local]==0.6.0", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From 7dcf5aff30354423894a1d12d85c723d63e7d2ca Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 22:05:28 +0000 Subject: [PATCH 58/86] update install commands for pytorch cpu only --- README.md | 4 ++-- docker/Dockerfile | 22 ++++++++++++++-------- pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 5823c173d..bdcfc899f 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Both options are described in detail in the [**Getting Started**](/docs/GETTING_ *TL;DR to install the Jax version for GPU run:* ```bash -pip3 install -e '.[pytorch_cpu]' +pip3 install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu pip3 install -e '.[jax_gpu]' pip3 install -e '.[full]' ``` @@ -63,7 +63,7 @@ pip3 install -e '.[full]' ```bash pip3 install -e '.[jax_cpu]' -pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121' +pip3 install -e '.[pytorch_gpu]' pip3 install -e '.[full]' ``` diff --git a/docker/Dockerfile b/docker/Dockerfile index db28b77c1..8b9011d72 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,6 +16,12 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg # Install prerequisites RUN apt-get update && apt-get install -y \ wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ curl \ vim @@ -48,10 +54,10 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip +RUN pip3 install --upgrade pip # Install Algorithmic efficiency repo -RUN pip install --upgrade pip +RUN pip3 install --upgrade pip RUN echo "Setting up algorithmic_efficiency repo" ARG branch="main" @@ -63,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip3 install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu; \ + && pip3 install -e '.[jax_gpu]' \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing 
Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip3 install -e '.[pytorch_gpu]'; \ + && pip3 install -e '.[jax_cpu]' \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip3 install -e '.[pytorch_gpu]'; \ + && pip3 install -e '.[jax_gpu]' \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ diff --git a/pyproject.toml b/pyproject.toml index 6ed38b412..0e36837de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,7 @@ jax_cpu = [ "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax[cuda12-local]==0.6.0", + "jax[cuda12]==0.6.0", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From 43356889a1d903ac91840d401b039ca429782c3a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 22:24:16 +0000 Subject: [PATCH 59/86] update dockerfile --- docker/Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 8b9011d72..d6cdc3d32 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -54,8 +54,6 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip3 install --upgrade pip - # Install Algorithmic efficiency repo RUN pip3 install --upgrade pip From 8d1fe7ef256d3179c218f83af56334d86a6b93d7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 22:31:13 +0000 Subject: [PATCH 60/86] update dockerfile --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index d6cdc3d32..fdb1871b8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -55,7 +55,7 @@ RUN mkdir -p data/ RUN mkdir -p experiment_runs/ # Install Algorithmic efficiency repo -RUN pip3 install --upgrade pip +RUN pip install --upgrade pip RUN echo "Setting up algorithmic_efficiency repo" ARG branch="main" From 240e2e568bce8fbfe66a157a06ce5dcdc73e0d1e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 22:32:33 +0000 Subject: [PATCH 61/86] update dockerfile --- docker/Dockerfile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index fdb1871b8..5f15a772b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -67,18 +67,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip3 install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu; \ - && pip3 install -e '.[jax_gpu]' \ + && pip install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu; \ + && pip install -e '.[jax_gpu]' \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip3 install -e '.[pytorch_gpu]'; \ - && pip3 install -e '.[jax_cpu]' \ + && pip install -e '.[pytorch_gpu]'; \ + && pip install -e '.[jax_cpu]' \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip3 install -e '.[pytorch_gpu]'; \ - && pip3 install -e '.[jax_gpu]' \ + && pip install -e '.[pytorch_gpu]'; \ + && pip install -e '.[jax_gpu]' 
\ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From cc8d6045b4489617d934f68a3900d1aa13acdc8e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 22:42:09 +0000 Subject: [PATCH 62/86] update dockerfile --- docker/Dockerfile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 5f15a772b..4fbc6d171 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -67,18 +67,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu; \ - && pip install -e '.[jax_gpu]' \ + && pip install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu \ + && pip install -e '.[jax_gpu]'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_gpu]'; \ - && pip install -e '.[jax_cpu]' \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_cpu]'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_gpu]'; \ - && pip install -e '.[jax_gpu]' \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_gpu]'; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From fe56eaf437086e3024811c965d876bbbb82d36a4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 May 2025 22:51:56 +0000 Subject: [PATCH 63/86] update dockerfile --- README.md | 2 +- docker/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bdcfc899f..abe9ad665 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Both options are described in detail in the [**Getting Started**](/docs/GETTING_ *TL;DR to install the Jax version for GPU run:* ```bash -pip3 install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu +pip3 install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu pip3 install -e '.[jax_gpu]' pip3 install -e '.[full]' ``` diff --git a/docker/Dockerfile b/docker/Dockerfile index 4fbc6d171..32e4a8ddf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -67,7 +67,7 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_cpu]' --index-url https://download.pytorch.org/whl/cpu \ + && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ && pip install -e '.[jax_gpu]'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ From de4c38b8482a099a8f3a3736d08aa63d92a9fe54 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 27 May 2025 19:37:37 +0000 Subject: [PATCH 64/86] modify initial model_state --- .../librispeech_deepspeech/librispeech_jax/models.py | 2 ++ .../librispeech_jax/workload.py | 12 ++++++------ pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index c8f49c830..2c7011445 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ 
b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -20,6 +20,8 @@ import jax from jax.experimental import rnn import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 8a54c6455..9d177eeba 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -50,8 +50,8 @@ def init_model_fn( variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state = variables[ - 'batch_stats'] if not self.layernorm_everywhere else {} + model_state = {'batch_stats': variables[ + 'batch_stats']} if not self.layernorm_everywhere else {} params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -80,6 +80,8 @@ def model_fn_ref( train=True, rngs={'dropout' : rng}, mutable=['batch_stats']) + if 'batch_stats' in new_model_state and new_model_state['batch_stats']: + new_model_state = jax.lax.pmean(new_model_state, 'batch') return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -109,11 +111,9 @@ def model_fn( model_fn_sharded = shard_map(model_fn_partial, jax.sharding.Mesh(jax.devices(), ('batch')), - in_specs=(None, P('batch'), None), - out_specs=(P('batch'), None), + in_specs=(P(), P('batch'), P(None)), + out_specs=(P('batch'), P(None)), ) - - model_fn_sharded = model_fn_partial return model_fn_sharded(params, augmented_and_preprocessed_input_batch, model_state,) diff --git a/pyproject.toml b/pyproject.toml index 0e36837de..2a82d9a52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] # Frameworks jax_core_deps = [ - "flax==0.8.4", + "flax==0.10.6", "optax==0.2.2", "chex==0.1.86", "ml_dtypes==0.5.1", From 5b7fb31423583956d7e503c836948dda6542329b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 27 May 2025 20:00:23 +0000 Subject: [PATCH 65/86] docker build script change --- docker/build_docker_images.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 645b81955..28ae028b9 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -11,10 +11,17 @@ do case "${flag}" in b) GIT_BRANCH=${OPTARG};; esac + case "${flag}" in + p) PROJECT=${OPTARG};; + esac done # Artifact repostiory -ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +if [ "$PROJECT" = "mlcommons-algoperf"]; then + ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +else + ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" +fi if [[ -z ${GIT_BRANCH+x} ]] then From 57b8fe69d79c34f660d731836d7137c45801c172 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 27 May 2025 20:55:52 +0000 Subject: [PATCH 66/86] temporarily use pre-releases for jax install --- docker/Dockerfile | 4 ++-- pyproject.toml | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 
32e4a8ddf..476956161 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -68,7 +68,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ - && pip install -e '.[jax_gpu]'; \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ @@ -78,7 +78,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_gpu]' \ - && pip install -e '.[jax_gpu]'; \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ diff --git a/pyproject.toml b/pyproject.toml index 2a82d9a52..3e02f883b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,12 @@ jax_cpu = [ "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax[cuda12]==0.6.0", + # Temporarily install with -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre + "jax", + "jaxlib", + "jax-cuda12-plugin[with-cuda]", + "jax-cuda12-pjrt" + # "jax[cuda12]==0.6.0", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From e23e99acad1fadc88a33e04c126fa7f17e5d96cb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 28 May 2025 18:33:15 +0000 Subject: [PATCH 67/86] fix to pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3e02f883b..4dff6de4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,7 @@ jax_gpu = [ "jax", "jaxlib", "jax-cuda12-plugin[with-cuda]", - "jax-cuda12-pjrt" + "jax-cuda12-pjrt", # "jax[cuda12]==0.6.0", "algoperf[jax_core_deps]", ] From 505fab2ea70b9fd0e1f34049a90ec0a487a9a052 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 01:05:55 +0000 Subject: [PATCH 68/86] chnage defaults for job config script --- scoring/utils/slurm/make_job_config.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py index 116e70459..411aac939 100644 --- a/scoring/utils/slurm/make_job_config.py +++ b/scoring/utils/slurm/make_job_config.py @@ -13,11 +13,8 @@ from absl import flags import jax -SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py' -EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2' -TUNING_SEARCH_SPACE = None -FRAMEWORK = 'pytorch' -TUNING_RULESET = 'self' +SUBMISSION_PATH = 'reference_algorithms/paper_baselines/adamw/jax/submission.py' +TUNING_SEARCH_SPACE = 'reference_algorithms/paper_baselines/adamw/tuning_search_space.json' flags.DEFINE_string( 'submission_path', @@ -29,17 +26,17 @@ 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.' 
) flags.DEFINE_string('experiment_dir', - EXPERIMENT_DIR, + '$HOME/experiments/', 'Path to experiment dir where logs will be saved.') flags.DEFINE_enum( 'framework', - FRAMEWORK, + 'jax', enum_values=['jax', 'pytorch'], help='Can be either pytorch or jax.') flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.') flags.DEFINE_enum( 'tuning_ruleset', - TUNING_RULESET, + 'external', enum_values=['external', 'self'], help='Which tuning ruleset to score this submission on. Can be external or self.' ) From 4acaffe5c792302303149bbaf5f767f418b0b272 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:14:02 +0000 Subject: [PATCH 69/86] fix docker image --- docker/Dockerfile | 2 ++ docker/build_docker_images.sh | 20 +++++++++++++------- pyproject.toml | 2 +- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 476956161..72e3a810f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -23,6 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ + liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 28ae028b9..6b5e67ceb 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -1,34 +1,40 @@ #!/bin/bash # Bash script to build and push dev docker images to artifact repo # Usage: -# bash build_docker_images.sh -b +# bash build_docker_images.sh -b -f # Make program exit with non-zero exit code if any command fails. set -e -while getopts b: flag +while getopts "b:p:f:" flag; do case "${flag}" in b) GIT_BRANCH=${OPTARG};; - esac - case "${flag}" in p) PROJECT=${OPTARG};; + f) FRAMEWORK=${OPTARG};; esac done # Artifact repostiory -if [ "$PROJECT" = "mlcommons-algoperf"]; then +if [ "$PROJECT" = "mlcommons-algoperf" ]; then ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" else ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" fi -if [[ -z ${GIT_BRANCH+x} ]] +if [[ -z ${GIT_BRANCH+x} ]]; then GIT_BRANCH='main' # Set default argument fi -for FRAMEWORK in "jax" "pytorch" "both" +FRAMEWORKS=( "jax" "pythorch" "both" ) + +if [[ -n "$FRAMEWORK" ]]; +then + FRAMEWORKS=("$FRAMEWORK") +fi + +for FRAMEWORK in "${FRAMEWORKS[@]}"; do IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . 
--build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH" diff --git a/pyproject.toml b/pyproject.toml index 4dff6de4b..3611e594b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "numpy>=2.0.2", "pandas>=2.0.1", "tensorflow==2.19.0", - "tensorflow-datasets==4.9.7", + "tensorflow-datasets==4.9.9", "tensorflow-probability==0.20.0", "tensorflow-addons==0.20.0", "gputil==1.4.0", From 3481f0e3ac8488057c9e165e09044b391258dd44 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 21:01:49 +0000 Subject: [PATCH 70/86] jax deprecation fix for jax.tree_map --- algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 0d6e8912d..45eb09a87 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -103,11 +103,11 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax.tree_map( + params = jax.tree.map( lambda x: jax.device_put(x, jax_sharding_utils.get_replicate_sharding()), params) - model_state = jax.tree_map( + model_state = jax.tree.map( lambda x: jax.device_put(x, jax_sharding_utils.get_replicate_sharding()), model_state) From 447d621fbdd105e63c9da0fde3ec3c8dec0a7f43 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 22:55:31 +0000 Subject: [PATCH 71/86] try to fix jax installation --- docker/Dockerfile | 4 ++-- pyproject.toml | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 72e3a810f..da8dbb131 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -70,7 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ - && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ + && pip install -U -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ @@ -80,7 +80,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_gpu]' \ - && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ + && pip install -U -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." 
>&2 \ && exit 1 ; \ diff --git a/pyproject.toml b/pyproject.toml index 3611e594b..172924edf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,6 @@ jax_gpu = [ "jaxlib", "jax-cuda12-plugin[with-cuda]", "jax-cuda12-pjrt", - # "jax[cuda12]==0.6.0", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From a3df78c74b9881e4214dd5af09cc01ea6cf08863 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 23:22:10 +0000 Subject: [PATCH 72/86] temporary pip install change for jax gpu nightly --- docker/Dockerfile | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index da8dbb131..fb3ab9bdf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -70,7 +70,9 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ - && pip install -U -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ + # Todo: remove temporary nightly install + # pip install -e '.[jax_gpu]' + && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ @@ -80,7 +82,9 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_gpu]' \ - && pip install -U -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ + # Todo: remove temporary nightly install + # pip install -e '.[jax_gpu]' + && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From 8aa3ffc825624063d8d0feb64f2c6c75521b3ec3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 01:52:58 +0000 Subject: [PATCH 73/86] add step_time to summary df --- scoring/score_submissions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index f07dc8cdd..154de6060 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -11,6 +11,7 @@ --compute_performance_profiles \ --output_dir scoring_results_self_tuning \ --self_tuning_ruleset + """ import operator @@ -99,6 +100,8 @@ def get_summary_df(workload, workload_df, include_test_split=False): lambda x: x['accumulated_submission_time'][int(x[ 'index to target on val'])] if x['val target reached'] else np.inf, axis=1) + + summary_df['step_time (s)'] = (workload_df['accumulated_submission_time'] / workload_df['global_step']).iloc[-1][-1] # test metrics if include_test_split: From 1cc068a1997c5b5e2a1021ec6e73bd87003ea4bd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 20:19:14 +0000 Subject: [PATCH 74/86] capture trace --- submission_runner.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..14841cade 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -162,6 +162,9 @@ 'Number of workers for ImageNet PyTorch evaluation data loaders.' 
'WARNING: Setting pytorch_eval_num_workers != 0, will result ' 'in incorrect evals currently, see issues/732.') +flags.DEFINE_boolean('capture_jax_trace', + False, + 'Captures jax profiler trace and writes to experiment directory.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -547,7 +550,8 @@ def score_submission_on_workload(workload: spec.Workload, save_checkpoints: Optional[bool] = True, hparam_start_index: Optional[bool] = None, hparam_end_index: Optional[bool] = None, - rng_seed: Optional[int] = None): + rng_seed: Optional[int] = None), + caputre_trace: Optional[bool] = False: # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -627,6 +631,9 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space[hi] = hyperparameters with profiler.profile('Train'): + if capture_trace: + jax.profiler.start_trace(log_dir), + logging.info('Capturing and saving jax trace to {log_dir}') timing, metrics = train_once(workload, workload_name, global_batch_size, global_eval_batch_size, @@ -641,6 +648,8 @@ def score_submission_on_workload(workload: spec.Workload, max_global_steps, tuning_dir_name, save_checkpoints=save_checkpoints,) + if capture_trace: + jax.profiler.stop_trace() all_timings[hi] = timing all_metrics[hi] = metrics logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') @@ -746,7 +755,8 @@ def main(_): save_checkpoints=FLAGS.save_checkpoints, hparam_start_index=FLAGS.hparam_start_index, hparam_end_index=FLAGS.hparam_end_index, - rng_seed=FLAGS.rng_seed) + rng_seed=FLAGS.rng_seed, + capture_trace=FLAGS.capture_jax_trace) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: From 274a9113e10e1d2871dfe4122ad4f51b4df315ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 20:38:23 +0000 Subject: [PATCH 75/86] add flag to skip evals --- submission_runner.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 14841cade..87eee0ead 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -165,6 +165,9 @@ flags.DEFINE_boolean('capture_jax_trace', False, 'Captures jax profiler trace and writes to experiment directory.') +flags.DEFINE_boolean('skip_evals', + False, + 'Skip evals on train eval, validation and test splits.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -212,7 +215,8 @@ def train_once( profiler: Profiler, max_global_steps: int = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True + save_checkpoints: Optional[bool] = True, + skip_evals: Optional[bool] = False, ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) @@ -388,7 +392,7 @@ def train_once( # Check if submission is eligible for an untimed eval. if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + workload.eval_period_time_sec or train_state['training_complete']) and not skip_evals: # Prepare for evaluation (timed). 
if prepare_for_eval is not None: @@ -550,8 +554,9 @@ def score_submission_on_workload(workload: spec.Workload, save_checkpoints: Optional[bool] = True, hparam_start_index: Optional[bool] = None, hparam_end_index: Optional[bool] = None, - rng_seed: Optional[int] = None), - caputre_trace: Optional[bool] = False: + rng_seed: Optional[int] = None, + capture_trace: Optional[bool] = False, + skip_evals: Optional[bool] = False): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -647,7 +652,8 @@ def score_submission_on_workload(workload: spec.Workload, profiler, max_global_steps, tuning_dir_name, - save_checkpoints=save_checkpoints,) + save_checkpoints=save_checkpoints, + skip_evals=skip_evals) if capture_trace: jax.profiler.stop_trace() all_timings[hi] = timing @@ -756,7 +762,8 @@ def main(_): hparam_start_index=FLAGS.hparam_start_index, hparam_end_index=FLAGS.hparam_end_index, rng_seed=FLAGS.rng_seed, - capture_trace=FLAGS.capture_jax_trace) + capture_trace=FLAGS.capture_jax_trace, + skip_evals=FLAGS.skip_evals,) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: From 2580f5cd3053af776b7feb51835433585944aee8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 01:11:42 +0000 Subject: [PATCH 76/86] add log dir to save traces to --- submission_runner.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 87eee0ead..b821f83dc 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -637,8 +637,8 @@ def score_submission_on_workload(workload: spec.Workload, with profiler.profile('Train'): if capture_trace: - jax.profiler.start_trace(log_dir), - logging.info('Capturing and saving jax trace to {log_dir}') + logging.info(f'Capturing and saving jax trace to {log_dir}') + jax.profiler.start_trace(f'{log_dir}/traces'), timing, metrics = train_once(workload, workload_name, global_batch_size, global_eval_batch_size, @@ -680,12 +680,17 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Creating directory at {log_dir}.') logger_utils.makedir(log_dir) with profiler.profile('Train'): + if capture_trace: + jax.profiler.start_trace('/algoperf/traces'), + logging.info(f'Capturing and saving jax trace to {log_dir}') score, _ = train_once( workload, workload_name, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, prepare_for_eval, None, rng_seed, rng, profiler, max_global_steps, log_dir, save_checkpoints=save_checkpoints) + if capture_trace: + jax.profiler.stop_trace() return score From 00d3810be65c0875d6aa2a58b8d078973113b5ca Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 01:20:52 +0000 Subject: [PATCH 77/86] remove editable flag from docker install for ml packages --- docker/Dockerfile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..4d0163b03 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -71,18 +71,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install '.[jax_gpu]' -f 
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ + && pip install '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install '.[jax_cpu]' \ + && pip install '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ + && pip install '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From c87d90845f4db476952e88fbdca4fd4c9206a4c8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 01:26:44 +0000 Subject: [PATCH 78/86] add cpu version for pytorch package to pyproject.toml --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 172924edf..694a135c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,10 @@ jax_gpu = [ "jax-cuda12-pjrt", "algoperf[jax_core_deps]", ] -pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] +pytorch_cpu = [ + "torch==2.5.1+cpu", + "torchvision==0.20.1+cpu" +] pytorch_gpu = [ "torch==2.5.1", "torchvision==0.20.1", From f387724ca69e62626cada9a8deb07fcc81e10586 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 05:45:24 +0000 Subject: [PATCH 79/86] decrease logging frequency --- reference_algorithms/paper_baselines/adamw/jax/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 9e85d3d84..0479746f1 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -174,7 +174,7 @@ def update_params( label_smoothing) # Log loss, grad_norm. 
- if global_step % 1 == 0 and workload.metrics_logger is not None: + if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { 'loss': loss.item(), From f4c60720622e0ea02ca30262ab1928030fcac66c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 06:02:51 +0000 Subject: [PATCH 80/86] fix pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b438b199e..8d682b787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,8 +117,8 @@ jax_gpu = [ "algoperf[jax_core_deps]", ] pytorch_cpu = [ - "torch==2.5.1+cpu", - "torchvision==0.20.1+cpu" + "torch==2.5.1, + "torchvision==0.20.1" ] pytorch_gpu = [ "torch==2.5.1", From 8616a64e9465db36c039679329c06e0b2a1f9dfc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 06:10:26 +0000 Subject: [PATCH 81/86] fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8d682b787..2d6f7dbcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,7 @@ jax_gpu = [ "algoperf[jax_core_deps]", ] pytorch_cpu = [ - "torch==2.5.1, + "torch==2.5.1", "torchvision==0.20.1" ] pytorch_gpu = [ From 89ddb7fd8b0c33bfda488d6a0a3ab049a7b1143f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 20:15:09 +0000 Subject: [PATCH 82/86] update dockerfile --- README.md | 2 +- algoperf/logger_utils.py | 1 - docker/Dockerfile | 7 +++---- pyproject.toml | 4 +--- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 4d530b3cd..a7a489822 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ Both options are described in detail in the [**Getting Started**](/docs/GETTING_ *TL;DR to install the Jax version for GPU run:* ```bash -pip3 install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu +pip3 install -e '.[pytorch_cpu]' --extra-index-url https://download.pytorch.org/whl/cpu pip3 install -e '.[jax_gpu]' pip3 install -e '.[full]' ``` diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py index c988956dc..dd6a3d785 100644 --- a/algoperf/logger_utils.py +++ b/algoperf/logger_utils.py @@ -226,7 +226,6 @@ def _get_system_software_info() -> Dict: return system_software_info - def _get_git_commit_hash() -> str: return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() diff --git a/docker/Dockerfile b/docker/Dockerfile index fb3ab9bdf..6328f7771 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,15 +69,14 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ + && pip install -e '.[pytorch_cpu]' --extra-index-url https://download.pytorch.org/whl/cpu \ # Todo: remove temporary nightly install # pip install -e '.[jax_gpu]' - && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ + && c; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_gpu]' \ - && pip install -e '.[jax_cpu]'; \ + && pip install -e '.[pytorch_gpu, jax_cpu]'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ diff --git a/pyproject.toml 
b/pyproject.toml index 2d6f7dbcf..9054ab44f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,7 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", + "Programming Languate :: Python :: 3.11", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ From 993fe6f1b0f5f5aa538fc825dbc3f022b6180a6e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:50:16 +0000 Subject: [PATCH 83/86] update installation instructions --- README.md | 12 ++++-------- docker/Dockerfile | 18 ++++-------------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index a7a489822..1046dbd0a 100644 --- a/README.md +++ b/README.md @@ -57,20 +57,16 @@ You can install this package and dependencies in a [Python virtual environment]( We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments. Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document. -*TL;DR to install the Jax version for GPU run:* +*TL;DR to install the Jax version for GPU and all workload dependencies run:* ```bash -pip3 install -e '.[pytorch_cpu]' --extra-index-url https://download.pytorch.org/whl/cpu -pip3 install -e '.[jax_gpu]' -pip3 install -e '.[full]' +pip3 install -e '.[pytorch_cpu,jax_gpu,full]' --extra-index-url https://download.pytorch.org/whl/cpu ``` -*TL;DR to install the PyTorch version for GPU run:* +*TL;DR to install the PyTorch version for GPU and all workload dependencies run:* ```bash -pip3 install -e '.[jax_cpu]' -pip3 install -e '.[pytorch_gpu]' -pip3 install -e '.[full]' +pip3 install -e '.[jax_cpu,pytorch_gpu,full]' ``` ## Getting Started diff --git a/docker/Dockerfile b/docker/Dockerfile index 6328f7771..4879d9612 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,28 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_cpu]' --extra-index-url https://download.pytorch.org/whl/cpu \ + && pip install -e '.[pytorch_cpu, full]' --extra-index-url https://download.pytorch.org/whl/cpu \ # Todo: remove temporary nightly install - # pip install -e '.[jax_gpu]' - && c; \ + && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_gpu, jax_cpu]'; \ - elif [ "$framework" = "both" ] ; then \ - echo "Installing Jax GPU and Pytorch GPU" \ - && cd /algorithmic-efficiency \ - && pip install -e '.[pytorch_gpu]' \ - # Todo: remove temporary nightly install - # pip install -e '.[jax_gpu]' - && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ + && pip install -e '.[pytorch_gpu, jax_cpu, full]'; \ else \ - echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ + echo "Invalid build-arg $framework: framework should be either jax or pytorch." 
>&2 \ && exit 1 ; \ fi -RUN cd /algorithmic-efficiency && pip install -e '.[full]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull From 20b726ad3dafcfff30f314e85cdc94d06312cff4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 2 Jul 2025 19:33:39 +0000 Subject: [PATCH 84/86] use jraph.batch_np instead of jraph.batch since jraph.batch with jnp leaks device memory --- algoperf/workloads/ogbg/input_pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index d2b0e874b..643e8077b 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -148,8 +148,10 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: + # jraph.batch has a memory leak and OOMs + # with jraph.batch_np we may have transferred the leak to the cpu.. yield { - 'inputs': jraph.batch(graphs_shards), + 'inputs': jraph.batch_np(graphs_shards), 'targets': np.vstack(labels_shards), 'weights': np.vstack(weights_shards) } From 9d1f91550ea645523f61b888370e814b4feb0b99 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 2 Jul 2025 19:33:49 +0000 Subject: [PATCH 85/86] modify documentation --- algoperf/workloads/ogbg/input_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 643e8077b..6b2f784ae 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -149,7 +149,8 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): if count == num_shards: # jraph.batch has a memory leak and OOMs - # with jraph.batch_np we may have transferred the leak to the cpu.. + # It is possible with jraph.batch_np we may have transferred the leak + # to the cpu. 
yield { 'inputs': jraph.batch_np(graphs_shards), 'targets': np.vstack(labels_shards), From 3486145e742eca0c61f596fbe1a3c790fc4aadd6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 3 Jul 2025 03:04:44 +0000 Subject: [PATCH 86/86] add plot util to visualize training metrics with wandb --- scoring/plot_utils/plot_curves.py | 46 ++++++++++++++++++++++++++ scoring/utils/slurm/make_job_config.py | 20 +++++++++-- 2 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 scoring/plot_utils/plot_curves.py diff --git a/scoring/plot_utils/plot_curves.py b/scoring/plot_utils/plot_curves.py new file mode 100644 index 000000000..96b5305eb --- /dev/null +++ b/scoring/plot_utils/plot_curves.py @@ -0,0 +1,46 @@ +from absl import flags +from absl import app + +import matplotlib.pyplot as plt +import pandas as pd +import os +import wandb + +flags.DEFINE_string('trial_dir', None, 'Path to trial dir') +flags.DEFINE_string() +FLAGS = flags.FLAGS + +MEASUREMENTS_FILENAME = 'eval_measurements.csv' + +def get_filename(trial_dir): + filename = os.path.join(trial_dir, MEASUREMENTS_FILENAME ) + return filename + +def main(_): + filename = get_filename(FLAGS.trial_dir) + + # Start a new W&B run --- + run = wandb.init(project="visualize-training-curve", name="conformer_jit") + + # Log the CSV as a versioned Artifact --- + artifact = wandb.Artifact(name="training-data", type="dataset") + artifact.add_file(filename) # Directly add the file + run.log_artifact(artifact) + + # Log the metrics for direct visualization --- + df = pd.read_csv(filename) + print(df.columns) + for index, row in df.iterrows(): + metrics = {col : row[col] for col in df.columns} + wandb.log(metrics, step=int(row["global_step"])) + + # Finish the W&B run --- + run.finish() + + + return + + +if __name__ == '__main__': + + app.run(main) \ No newline at end of file diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py index 411aac939..04affb852 100644 --- a/scoring/utils/slurm/make_job_config.py +++ b/scoring/utils/slurm/make_job_config.py @@ -15,6 +15,8 @@ SUBMISSION_PATH = 'reference_algorithms/paper_baselines/adamw/jax/submission.py' TUNING_SEARCH_SPACE = 'reference_algorithms/paper_baselines/adamw/tuning_search_space.json' +NUM_TUNING_TRIALS = 3 # For external tuning ruleset +NUM_STUDIES = 3 flags.DEFINE_string( 'submission_path', @@ -40,13 +42,21 @@ enum_values=['external', 'self'], help='Which tuning ruleset to score this submission on. Can be external or self.' ) +flags.DEFINE_string( + 'workloads', + None, + help='Comma seperated list of workloads to run.' +) +flags.DEFINE_integer( + 'num_studies', + NUM_STUDIES, + help='Number of studies.' +) FLAGS = flags.FLAGS MIN_INT = -2**(31) MAX_INT = 2**(31) - 1 -NUM_TUNING_TRIALS = 5 # For external tuning ruleset -NUM_STUDIES = 3 WORKLOADS = { "imagenet_resnet": {"dataset": "imagenet"}, @@ -61,7 +71,11 @@ def main(_): - workloads = WORKLOADS.keys() + if not FLAGS.workloads: + workloads = WORKLOADS.keys() + else: + workloads = FLAGS.workloads.split(',') + key = jax.random.key(FLAGS.seed) jobs = []
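Note on `scoring/plot_utils/plot_curves.py` as added in the hunk above: the second `flags.DEFINE_string()` call has no arguments, so the module would fail with a `TypeError` on import, and the W&B project and run names are hard-coded in `wandb.init`. Below is a minimal sketch of how the flag block could be written so the script imports cleanly; the `--wandb_run_name` flag and its default are illustrative assumptions, not part of the patch.

```python
# Sketch only: assumed flag layout for plot_curves.py, not the committed code.
from absl import flags

flags.DEFINE_string(
    'trial_dir', None,
    'Path to a trial dir containing eval_measurements.csv.')
# Hypothetical flag; the patch above hard-codes the run name 'conformer_jit'.
flags.DEFINE_string(
    'wandb_run_name', 'conformer_jit',
    'Name to use for the W&B run.')
flags.mark_flag_as_required('trial_dir')
FLAGS = flags.FLAGS
```

A hypothetical invocation, assuming the script keeps its current location, would be `python3 scoring/plot_utils/plot_curves.py --trial_dir=<experiment_dir>/<workload>/trial_1`.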