diff --git a/README.md b/README.md index 8e470266d..1046dbd0a 100644 --- a/README.md +++ b/README.md @@ -57,20 +57,16 @@ You can install this package and dependencies in a [Python virtual environment]( We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments. Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document. -*TL;DR to install the Jax version for GPU run:* +*TL;DR to install the Jax version for GPU and all workload dependencies run:* ```bash -pip3 install -e '.[pytorch_cpu]' -pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' -pip3 install -e '.[full]' +pip3 install -e '.[pytorch_cpu,jax_gpu,full]' --extra-index-url https://download.pytorch.org/whl/cpu ``` -*TL;DR to install the PyTorch version for GPU run:* +*TL;DR to install the PyTorch version for GPU and all workload dependencies run:* ```bash -pip3 install -e '.[jax_cpu]' -pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121' -pip3 install -e '.[full]' +pip3 install -e '.[jax_cpu,pytorch_gpu,full]' ``` ## Getting Started diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..75d8d59ea 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -11,7 +11,6 @@ from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax import numpy as np from tensorflow.io import gfile # pytype: disable=import-error import torch @@ -193,10 +192,7 @@ def save_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - model_params = jax.device_get(jax_utils.unreplicate(model_params)) opt_state, _ = optimizer_state - opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) - model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( model_params, diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..9a7b91b15 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,6 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler +from algoperf import jax_sharding_utils from algoperf import spec @@ -60,10 +61,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - # Reshape (global_batch_size, ...) to - # (local_device_count, per_device_batch_size, ...). - # Assumes that `global_batch_size % local_device_count == 0`. 
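The deleted comment described the old `pmap` layout: reshape the global batch to `(local_device_count, per_device_batch_size, ...)`. The replacement instead hands the full batch to JAX with a batch-dimension `NamedSharding`. A minimal sketch of that pattern (names mirror the `jax_sharding_utils` module introduced below; it assumes the batch size is divisible by the number of devices):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_batch_along_devices(batch):
  # One mesh axis named 'batch' spanning all available devices.
  mesh = Mesh(jax.devices(), ('batch',))
  batch_sharding = NamedSharding(mesh, P('batch'))
  # Leaves keep their global shape; axis 0 is split across the devices.
  return jax.tree.map(lambda x: jax.device_put(x, batch_sharding), batch)


# With N devices, a (128, 32) array becomes N shards of shape (128 // N, 32).
batch = {'inputs': np.zeros((128, 32), np.float32),
         'targets': np.zeros((128,), np.int32)}
sharded = shard_batch_along_devices(batch)
```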
- return x.reshape((local_device_count, -1, *x.shape[1:])) + return jax.device_put(x, jax_sharding_utils.get_batch_dim_sharding()) return jax.tree.map(_prepare, batch) diff --git a/algoperf/jax_sharding_utils.py b/algoperf/jax_sharding_utils.py new file mode 100644 index 000000000..6c90c5cd7 --- /dev/null +++ b/algoperf/jax_sharding_utils.py @@ -0,0 +1,37 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import NamedSharding, PartitionSpec as P + + +def get_replicate_sharding(): + """Returns a sharding spec that replicates data across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return NamedSharding(mesh, P()) + + +def get_batch_dim_sharding(): + """Returns a sharding spec that shards data along the first axis.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return NamedSharding(mesh, P('batch')) + + +def shard_along_batch_dim(x): + """Shards a tensor across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return jax.tree.map( + lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x) + + +def replicate(x): + """Replicates tensor across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return jax.tree.map( + lambda x: jax.device_put(x, NamedSharding(mesh, P())), x) + + +def display_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") \ No newline at end of file diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py index c988956dc..dd6a3d785 100644 --- a/algoperf/logger_utils.py +++ b/algoperf/logger_utils.py @@ -226,7 +226,6 @@ def _get_system_software_info() -> Dict: return system_software_info - def _get_git_commit_hash() -> str: return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..8eec88f28 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -171,5 +171,5 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..48cee94b4 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax @@ -13,6 +12,7 @@ import tensorflow_datasets as tfds from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -31,6 +31,7 @@ def _build_cifar_dataset( repeat_final_dataset: Optional[bool] = None ) -> Iterator[Dict[str, spec.Tensor]]: ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) + ds_builder.download_and_prepare() train = split == 'train' assert self.num_train_examples + self.num_validation_examples == 50000 if split in ['train', 'eval_train']: @@ -96,8 +97,8 @@ def init_model_fn( model_state, params = pop(variables, 
'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) + params = jax_sharding_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -175,35 +176,51 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') return metrics - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng + ), + ) + def _eval_model_jitted( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + """Return the mean accuracy and loss as a dict.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + weights = batch.get('weights') + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch['targets'], weights) + + metrics = _eval_model_jitted(params, + batch, + model_state, + rng) + return jax.tree.map(lambda x: x.item(), metrics) def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: x / num_examples, total_metrics) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..723326120 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf import jax_sharding_utils from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -105,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return jax_sharding_utils.replicate(initial_params), None def is_output_params(self, 
param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -129,13 +130,16 @@ def model_fn( return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + ), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding()) + def _eval_batch_jitted(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( params, batch, @@ -156,8 +160,7 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..1349cef64 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -10,6 +10,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import jax_sharding_utils import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim @@ -39,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_utils.replicate(params) + params = jax_sharding_utils.replicate(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -94,10 +95,12 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding()), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], @@ -126,7 +129,6 @@ def _eval_model(self, 'ssim': ssim_sum, 'loss': summed_loss, } - metrics = jax.lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -154,13 +156,12 @@ def _eval_model_on_split(self, num_batches=num_batches) total_metrics = {'ssim': 0., 'loss': 0.} - eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. 
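The `_eval_model` decorator above is the recurring replacement for `pmap(..., axis_name='batch')`: `jax.jit` with explicit `in_shardings`/`out_shardings`, where sums over the batch-sharded inputs are already global, so no `lax.psum` and no per-device `[0]` indexing is needed afterwards. A toy version of the same wiring (hypothetical linear model; assumes the batch size divides evenly across devices):

```python
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P('batch'))


@functools.partial(
    jax.jit,
    in_shardings=(replicated, batch_sharded),
    out_shardings=replicated)
def eval_step(params, batch):
  logits = batch['inputs'] @ params['w']
  correct = jnp.sum(jnp.argmax(logits, axis=-1) == batch['targets'])
  # The sum above is already a global sum over the sharded batch.
  return {'correct': correct, 'count': jnp.asarray(batch['targets'].shape[0])}


params = jax.device_put({'w': jnp.zeros((16, 10))}, replicated)
batch = jax.device_put({'inputs': jnp.zeros((32, 16)),
                        'targets': jnp.zeros((32,), jnp.int32)},
                       batch_sharded)
metrics = eval_step(params, batch)  # replicated scalars on every device
```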
- synced_metrics = self._eval_model(params, batch, eval_rngs) + synced_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..35bc3635c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -399,6 +399,7 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. - it = jax_utils.prefetch_to_device(it, 2) + # TODO (kasimbeg): put on device + # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..45eb09a87 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -20,6 +20,7 @@ from algoperf import param_utils from algoperf import random_utils as prng +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -71,17 +72,6 @@ def _build_dataset( use_randaug=use_randaug) return ds - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() # Create a shallow copy - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, @@ -113,18 +103,29 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = jax.tree.map( + lambda x: jax.device_put(x, + jax_sharding_utils.get_replicate_sharding()), + params) + model_state = jax.tree.map( + lambda x: jax.device_put(x, + jax_sharding_utils.get_replicate_sharding()), + model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_replicate_sharding(), # rng + ), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -218,7 +219,6 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -231,9 +231,6 @@ def _eval_model_on_split(self, data_dir: str, global_step: int = 0) -> Dict[str, float]: del global_step - if model_state is not None: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) data_rng, eval_rng = prng.split(rng, 2) # We already repeat the dataset indefinitely in tf.data. @@ -250,20 +247,14 @@ def _eval_model_on_split(self, eval_metrics = {} for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) - step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) - # We already average these metrics across devices inside _compute_metrics. 
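As in the eval loop above, the per-device `prng.split` disappears: one key is folded in per batch and passed whole to the jitted eval step. A tiny sketch with `jax.random` standing in for `algoperf.random_utils`:

```python
import jax

eval_rng = jax.random.PRNGKey(0)
for batch_idx in range(4):
  # A deterministic, distinct key per batch; no split across devices needed,
  # because the jitted eval step receives the key replicated.
  batch_rng = jax.random.fold_in(eval_rng, batch_idx)
```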
- synced_metrics = self._eval_model(params, - batch, - model_state, - step_eval_rngs) + synced_metrics = self._eval_model(params, batch, model_state, eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), - eval_metrics) + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..c4d823319 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -2,13 +2,13 @@ from typing import Dict, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax import jax.numpy as jnp from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload @@ -46,8 +46,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..758344a23 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,11 +2,9 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp import numpy as np import optax @@ -14,6 +12,7 @@ from algoperf import data_utils from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics from algoperf.workloads.librispeech_conformer import workload @@ -93,8 +92,11 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + + # Add sharding + params = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -180,6 +182,7 @@ def _build_input_queue( 'targets': (targets.numpy(), target_paddings.numpy()), } + # Use data_utils.shard_and_maybe_pad_np to handle sharding padded_batch = data_utils.shard_and_maybe_pad_np( numpy_batch, padding_value=1.0) yield padded_batch @@ -305,11 +308,16 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) - def eval_step_pmapped( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # 
params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_replicate_sharding(), # rng + ), + out_shardings=jax_sharding_utils.get_batch_dim_sharding(), + static_argnums=(0,)) + def _eval_step( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -325,15 +333,45 @@ def eval_step_pmapped( decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) - targets, target_paddings = batch['targets'] - return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') + # Convert metrics bundle to dictionary + metrics_dict = { + 'loss_per_example': + loss['per_example'], + 'decoded': + decoded, + 'decoded_paddings': + decoded_paddings, + 'targets': + targets, + 'target_paddings': + target_paddings, + 'n_valid_examples': + jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + } + return metrics_dict + + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): + """Evaluates the model and returns a metrics bundle.""" + metrics_dict = self._eval_step(params, batch, model_state, rng) + + # Convert dictionary back to metrics bundle + metrics_bundle = self.metrics_bundle.single_from_model_output( + loss_dict={ + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + }, + decoded=metrics_dict['decoded'], + decoded_paddings=metrics_dict['decoded_paddings'], + targets=metrics_dict['targets'], + target_paddings=metrics_dict['target_paddings']) + + return metrics_bundle def _eval_model_on_split(self, split: str, @@ -346,9 +384,6 @@ def _eval_model_on_split(self, global_step: int = 0) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step - if model_state is not None and len(model_state) > 0: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: @@ -358,10 +393,7 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step(params, eval_batch, model_state, rng) if metrics_report is None: metrics_report = computed_metrics @@ -373,16 +405,6 @@ def _eval_model_on_split(self, return computed_metrics - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. 
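For context on the helper being deleted here (its body continues just below): under `pmap`, each device carried its own copy of the batch-norm statistics, so they had to be `pmean`-averaged explicitly. With `jit` and a replicated `model_state` there is only one logical copy, so the sync step has nothing left to do. A sketch of the old per-replica layout and the averaging it required (hypothetical shapes):

```python
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()
# pmap-era layout: a leading device axis, one (possibly diverged) copy each.
per_replica_stats = {'mean': jnp.arange(float(n_dev)).reshape(n_dev, 1)}

avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
synced = avg_fn(per_replica_stats)  # every replica now holds the average
```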
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..2c7011445 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -8,13 +8,20 @@ # webpage : https://bastings.github.io/ """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +from absl import logging +import numpy as np + +import functools +import flax from flax import linen as nn from flax import struct import jax from jax.experimental import rnn import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor @@ -310,16 +317,12 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v variance = (inputs - mean) * (inputs - mean) * mask sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..9d177eeba 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -4,10 +4,13 @@ from flax import jax_utils import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P import numpy as np from algoperf import param_utils from algoperf import spec +from algoperf import jax_sharding_utils from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -41,21 +44,22 @@ def init_model_fn( fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) + # model_init_fn = functools.partial(self._model.init, train=False) params_rng, dropout_rng = jax.random.split(rng, 2) variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state = variables[ - 'batch_stats'] if not self.layernorm_everywhere else {} + model_state = {'batch_stats': variables[ + 'batch_stats']} if not self.layernorm_everywhere else {} params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) + params = 
jax_sharding_utils.replicate(params) return params, model_state - - def model_fn( + + def model_fn_ref( self, params: spec.ParameterContainer, augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], @@ -76,6 +80,8 @@ def model_fn( train=True, rngs={'dropout' : rng}, mutable=['batch_stats']) + if 'batch_stats' in new_model_state and new_model_state['batch_stats']: + new_model_state = jax.lax.pmean(new_model_state, 'batch') return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -86,6 +92,32 @@ def model_fn( mutable=False) return (logits, logit_paddings), model_state + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + model_fn_partial = jax.tree_util.Partial(self.model_fn_ref, + mode=mode, + rng=rng, + update_batch_norm=update_batch_norm, + use_running_average_bn=use_running_average_bn) + + model_fn_sharded = shard_map(model_fn_partial, + jax.sharding.Mesh(jax.devices(), ('batch')), + in_specs=(P(), P('batch'), P(None)), + out_specs=(P('batch'), P(None)), + ) + return model_fn_sharded(params, + augmented_and_preprocessed_input_batch, + model_state,) + def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..ad2d7fc8a 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -6,11 +6,11 @@ from flax import jax_utils from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.mnist.workload import BaseMnistWorkload @@ -46,7 +46,7 @@ def init_model_fn( train=True)['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' @@ -101,10 +101,14 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng + ), + static_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -125,11 +129,10 @@ def _eval_model( (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} - metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x.item() / num_examples), total_metrics) diff --git 
a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..6b2f784ae 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -148,17 +148,13 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: - - def f(x): - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) - - graphs_shards = f(graphs_shards) - labels_shards = f(labels_shards) - weights_shards = f(weights_shards) + # jraph.batch has a memory leak and OOMs + # It is possible with jraph.batch_np we may have transferred the leak + # to the cpu. yield { - 'inputs': graphs_shards, - 'targets': labels_shards, - 'weights': weights_shards, + 'inputs': jraph.batch_np(graphs_shards), + 'targets': np.vstack(labels_shards), + 'weights': np.vstack(weights_shards) } count = 0 diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..7d7de1ecb 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -2,6 +2,7 @@ # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. from typing import Optional, Tuple +import jax from flax import linen as nn import jax.numpy as jnp import jraph diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..abfd70504 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,6 +8,7 @@ import jraph import optax +from algoperf import jax_sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.ogbg import metrics @@ -45,7 +46,8 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(params), None + params = jax_sharding_utils.replicate(params) + return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' @@ -106,11 +108,16 @@ def _eval_metric(self, labels, logits, masks): return metrics.EvalMetrics.single_from_model_output( loss=loss['per_example'], logits=logits, labels=labels, mask=masks) + @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_replicate_sharding()), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding(), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) @@ -119,7 +126,6 @@ def _normalize_eval_metrics( Any]) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples - total_metrics = total_metrics.reduce() return {k: float(v) for k, v in total_metrics.compute().items()} diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..45ea778fd 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -161,6 +161,7 @@ def _eval_batch(self, spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) + # jax.debug.print(str(logits)) return self._eval_metric(batch['targets'], logits, 
batch['weights']) def _eval_model_on_split(self, diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..5e175320a 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -283,8 +283,7 @@ def ref_stats(output, refs): closest_diff = diff closest_len = reflen elif diff == closest_diff: - if reflen < closest_len: - closest_len = reflen + closest_len = min(reflen, closest_len) ngrams_ref = extract_ngrams(ref) for ngram in ngrams_ref: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..36f5b8606 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -14,6 +14,7 @@ import optax from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode @@ -69,10 +70,16 @@ def compute_weighted_cross_entropy( } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + ), + static_argnums=(0,), # self + ) + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] @@ -90,29 +97,29 @@ def eval_step_pmapped( 'denominator': weight_sum, } - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: - replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) - @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + jax_sharding_utils.get_batch_dim_sharding(), # inputs + ), + static_argnums=( + 0, + 2, + )) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), + dummy_inputs = jax_sharding_utils.shard_along_batch_dim( + jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = jax_sharding_utils.shard_along_batch_dim( jnp.ones(target_shape, jnp.float32)) + initial_variables = models.Transformer(config).init( + jax.random.PRNGKey(0), dummy_inputs, dummy_targets) return initial_variables['cache'] - # eos_id, max_decode_len are constant. 
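The removed comment ("eos_id, max_decode_len are constant") is what `static_argnums` expresses in the jitted version of `predict_step` further below: static arguments are baked into the compiled program, and each new value triggers a separate compilation. A minimal illustration on a toy function (hypothetical, not the workload's decoder):

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(1,))
def pad_to_length(x, max_len):
  # max_len is static, so the output shape is known at trace time.
  return jnp.pad(x, (0, max_len - x.shape[0]))


print(pad_to_length(jnp.ones(3), 8).shape)  # compiles once for max_len=8
print(pad_to_length(jnp.ones(5), 8).shape)  # recompiles: new input shape
```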
- @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) def predict_step(self, inputs: spec.Tensor, params: spec.ParameterContainer, @@ -180,20 +187,35 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] + jitted_predict_step = None for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) - predicted = _to_host(predicted) - targets = _to_host(pred_batch['targets']) + if jitted_predict_step is None: + jitted_predict_step = jax.jit( + self.predict_step, + in_shardings=( + jax_sharding_utils.get_batch_dim_sharding(), # inputs + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_replicate_sharding(), # cache + ), + static_argnums=( + 3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) + predicted = jitted_predict_step(pred_batch['inputs'], + params, + cache, + decode.EOS_ID, + max_predict_length) + # predicted = _to_host(predicted) + # targets = _to_host(pred_batch['targets']) + targets = pred_batch['targets'] # Find actual batch size, ignoring the potential padding. weights = pred_batch.get('weights') if weights is not None: - weights = _to_host(weights) + # weights = _to_host(weights) actual_batch_size = int(weights.sum(0)[0].item()) else: actual_batch_size = len(predicted) @@ -213,7 +235,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - init_fake_batch_size = 2 + init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -235,15 +257,21 @@ def init_model_fn( eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) params_rng, dropout_rng = jax.random.split(rng) + inputs = jnp.ones(input_shape, jnp.float32) + targets = jnp.ones(target_shape, jnp.float32) + sharded_inputs = jax_sharding_utils.shard_along_batch_dim(inputs) + sharded_targets = jax_sharding_utils.shard_along_batch_dim(targets) + initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + sharded_inputs, + sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + params = jax_sharding_utils.shard_along_batch_dim(initial_params) + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..4879d9612 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" @@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ - libbz2-dev \ liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 @@ -56,8 +56,6 @@ RUN echo "Setting up directories for 
data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip - # Install Algorithmic efficiency repo RUN pip install --upgrade pip @@ -71,25 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_cpu, full]' --extra-index-url https://download.pytorch.org/whl/cpu \ + # Todo: remove temporary nightly install + && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ - elif [ "$framework" = "both" ] ; then \ - echo "Installing Jax GPU and Pytorch GPU" \ - && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu, jax_cpu, full]'; \ else \ - echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ + echo "Invalid build-arg $framework: framework should be either jax or pytorch." >&2 \ && exit 1 ; \ fi -RUN cd /algorithmic-efficiency && pip install -e '.[full]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 645b81955..6b5e67ceb 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -1,27 +1,40 @@ #!/bin/bash # Bash script to build and push dev docker images to artifact repo # Usage: -# bash build_docker_images.sh -b +# bash build_docker_images.sh -b -f # Make program exit with non-zero exit code if any command fails. set -e -while getopts b: flag +while getopts "b:p:f:" flag; do case "${flag}" in b) GIT_BRANCH=${OPTARG};; + p) PROJECT=${OPTARG};; + f) FRAMEWORK=${OPTARG};; esac done # Artifact repostiory -ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +if [ "$PROJECT" = "mlcommons-algoperf" ]; then + ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +else + ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" +fi -if [[ -z ${GIT_BRANCH+x} ]] +if [[ -z ${GIT_BRANCH+x} ]]; then GIT_BRANCH='main' # Set default argument fi -for FRAMEWORK in "jax" "pytorch" "both" +FRAMEWORKS=( "jax" "pythorch" "both" ) + +if [[ -n "$FRAMEWORK" ]]; +then + FRAMEWORKS=("$FRAMEWORK") +fi + +for FRAMEWORK in "${FRAMEWORKS[@]}"; do IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . 
--build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH" diff --git a/pyproject.toml b/pyproject.toml index 4e15e4400..9054ab44f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,7 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", + "Programming Languate :: Python :: 3.11", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ @@ -38,8 +36,8 @@ dependencies = [ "docker==7.1.0", "numpy>=2.0.2", "pandas>=2.0.1", - "tensorflow==2.18.0", - "tensorflow-datasets==4.9.7", + "tensorflow==2.19.0", + "tensorflow-datasets==4.9.9", "tensorflow-probability==0.20.0", "gputil==1.4.0", "psutil==6.1.0", @@ -91,32 +89,35 @@ fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"] ogbg = ["jraph==0.0.6.dev0", "scikit-learn==1.5.2"] librispeech_conformer = [ "sentencepiece==0.2.0", - "tensorflow-text==2.18.0", + "tensorflow-text==2.19.0", "pydub==0.25.1", ] -wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] # Frameworks jax_core_deps = [ - "flax==0.8.4", + "flax==0.10.6", "optax==0.2.2", "chex==0.1.86", - "ml_dtypes==0.4.1", + "ml_dtypes==0.5.1", "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", + "jax==0.6.0", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + # Temporarily install with -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre + "jax", + "jaxlib", + "jax-cuda12-plugin[with-cuda]", + "jax-cuda12-pjrt", "algoperf[jax_core_deps]", ] -pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] +pytorch_cpu = [ + "torch==2.5.1", + "torchvision==0.20.1" +] pytorch_gpu = [ "torch==2.5.1", "torchvision==0.20.1", diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..0479746f1 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -7,8 +7,11 @@ import jax from jax import lax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -50,24 +53,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -77,7 +74,7 @@ def _loss_fn(params): model_state, spec.ForwardPassMode.TRAIN, rng, - update_batch_norm=True) + update_batch_norm=True,) loss_dict = workload.loss_fn( 
label_batch=batch['targets'], logits_batch=logits, @@ -90,9 +87,8 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + + # Compute local loss and gradients loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -105,7 +101,7 @@ def _loss_fn(params): grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm @@ -130,7 +126,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -139,23 +134,51 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Set up mesh and sharding + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings= ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ), + out_shardings=( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + )) + # print(batch) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) # Log loss, grad_norm. 
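The `donate_argnums=(2, 3, 4)` in the jit call above lets XLA reuse the buffers of the old model state, optimizer state, and parameters for the corresponding outputs, which matters at these model sizes. A small sketch of the donation idiom (assumed SGD-style update, not the submission's optimizer):

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, donate_argnums=(0,))
def sgd_update(params, grads, lr):
  return jax.tree.map(lambda p, g: p - lr * g, params, grads)


params = {'w': jnp.ones((1024, 1024))}
grads = {'w': jnp.ones((1024, 1024))}
# Rebind the result: the donated input buffers must not be reused afterwards.
params = sgd_update(params, grads, 0.1)
```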
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -205,6 +228,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 32 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..1e42d7b94 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -7,8 +7,11 @@ import jax from jax import lax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +40,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,21 +90,21 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +# @functools.partial( +# jax.pmap, +# axis_name='batch', +# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), +# static_broadcasted_argnums=(0, 1), +# donate_argnums=(2, 3, 4)) +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -124,9 +127,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # # Get correct global mean loss and grad. 
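The `lax.psum` that disappears here is subsumed by jit's global view: a reduction over a batch-sharded array already yields the global value, with XLA inserting the cross-device collective. A tiny check of that behaviour (assumes the array length is divisible by the device count):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P('batch')))
print(jax.jit(jnp.sum)(x))  # 28.0: summed across all shards, no explicit psum
```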
loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -164,7 +165,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -173,23 +173,54 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + mesh = jax_sharding_utils.get_mesh() + # Create shardings for each argument + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rngs + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -239,6 +270,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..b99d1fd94 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -41,4 +41,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..9da67e8f9 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=False) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -168,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..533e23f2c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..20f015821 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,26 +7,21 @@ import jax.numpy as jnp import optax +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - 
-                       current_param_container,
-                       batch,
-                       rng,
-                       grad_clip,
-                       label_smoothing):
+def train_step(workload,
+               opt_update_fn,
+               model_state,
+               optimizer_state,
+               current_param_container,
+               batch,
+               rng,
+               grad_clip,
+               label_smoothing):

   def _loss_fn(params):
     """Loss function used for training."""
@@ -49,9 +44,7 @@ def _loss_fn(params):
   grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
   (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
       current_param_container)
-  # Get correct global mean loss and grad.
-  (summed_loss, n_valid_examples, grad) = lax.psum(
-      (summed_loss, n_valid_examples, grad), axis_name='batch')
+  # Compute mean loss and grad (jit handles the cross-device reduction).
   loss = summed_loss / n_valid_examples
   grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
@@ -98,9 +91,43 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
-  new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step(  # pylint: disable=line-too-long
+  # Build shardings with the helpers from algoperf.jax_sharding_utils:
+  # replicate everything except the batch, which is split along its leading axis.
+  replicated = jax_sharding_utils.get_replicate_sharding()  # No partitioning
+  sharded = jax_sharding_utils.get_batch_dim_sharding()  # Partition along batch
+
+  # Create the sharding rules for each argument
+  arg_shardings = (
+      # workload is static
+      # opt_update_fn is static
+      replicated,  # model_state
+      replicated,  # optimizer_state
+      replicated,  # current_param_container
+      sharded,  # batch
+      replicated,  # rng
+      replicated,  # grad_clip
+      replicated  # label_smoothing
+  )
+  out_shardings = (
+      replicated,  # new_optimizer_state
+      replicated,  # updated_params
+      replicated,  # new_model_state
+      replicated,  # loss
+      replicated  # grad_norm
+  )
+
+  # Jit with shardings
+  jitted_train_step = jax.jit(
+      train_step,
+      static_argnums=(0, 1),
+      donate_argnums=(2, 3, 4),
+      in_shardings=arg_shardings,
+      out_shardings=out_shardings)
+
+  new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(
       workload, opt_update_fn, model_state, optimizer_state,
-      current_param_container, batch, per_device_rngs, grad_clip,
+      current_param_container, batch, rng, grad_clip,
       label_smoothing)

   # Log loss, grad_norm.
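Reviewer note, not part of the patch: a minimal, self-contained sketch of the data-parallel pattern the hunks above switch to, i.e. `jax.jit` with explicit `NamedSharding`s in place of `jax.pmap`. All names here (`toy_loss`, `train_step`, the shapes) are illustrative only, and the example assumes the global batch size is divisible by the number of local devices.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One data-parallel mesh axis, mirroring jax_sharding_utils.
mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())            # same copy on every device
batch_sharded = NamedSharding(mesh, P('batch'))  # split along the leading axis


def toy_loss(params, batch):
  preds = batch['inputs'] @ params['w']
  return jnp.mean((preds - batch['targets']) ** 2)


def train_step(params, batch):
  # Differentiate w.r.t. params and take a plain SGD step.
  loss, grad = jax.value_and_grad(toy_loss)(params, batch)
  new_params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grad)
  return new_params, loss


# jit compiles one global (SPMD) program; XLA inserts the cross-device
# reductions that pmap needed an explicit lax.psum for.
jitted_train_step = jax.jit(
    train_step,
    in_shardings=(replicated, batch_sharded),
    out_shardings=(replicated, replicated))

params = {'w': jnp.zeros((8, 1))}
batch = {
    'inputs': jnp.ones((16, 8)),    # global batch of 16 examples
    'targets': jnp.zeros((16, 1)),
}
# device_put plays the role of the input pipeline: lay the global batch out
# across devices along its first axis and replicate the params.
params = jax.device_put(params, replicated)
batch = jax.device_put(batch, batch_sharded)
params, loss = jitted_train_step(params, batch)
```

Because the jitted step returns global values, the patch can also drop the `loss[0]` / `grad_norm[0]` indexing used under pmap: `loss` and `grad_norm` come back as single scalars.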
@@ -108,8 +135,8 @@ def update_params(
       workload.metrics_logger is not None):
     workload.metrics_logger.append_scalar_metrics(
         {
-            'loss': loss[0],
-            'grad_norm': grad_norm[0],
+            'loss': loss.item(),
+            'grad_norm': grad_norm.item(),
         }, global_step)
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
diff --git a/scoring/plot_utils/plot_curves.py b/scoring/plot_utils/plot_curves.py
new file mode 100644
index 000000000..96b5305eb
--- /dev/null
+++ b/scoring/plot_utils/plot_curves.py
@@ -0,0 +1,46 @@
+from absl import app
+from absl import flags
+
+import matplotlib.pyplot as plt
+import os
+import pandas as pd
+import wandb
+
+flags.DEFINE_string('trial_dir', None, 'Path to trial dir')
+FLAGS = flags.FLAGS
+
+MEASUREMENTS_FILENAME = 'eval_measurements.csv'
+
+
+def get_filename(trial_dir):
+  filename = os.path.join(trial_dir, MEASUREMENTS_FILENAME)
+  return filename
+
+
+def main(_):
+  filename = get_filename(FLAGS.trial_dir)
+
+  # Start a new W&B run.
+  run = wandb.init(project="visualize-training-curve", name="conformer_jit")
+
+  # Log the CSV as a versioned artifact.
+  artifact = wandb.Artifact(name="training-data", type="dataset")
+  artifact.add_file(filename)  # Directly add the file.
+  run.log_artifact(artifact)
+
+  # Log the metrics for direct visualization.
+  df = pd.read_csv(filename)
+  print(df.columns)
+  for _, row in df.iterrows():
+    metrics = {col: row[col] for col in df.columns}
+    wandb.log(metrics, step=int(row["global_step"]))
+
+  # Finish the W&B run.
+  run.finish()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py
index f07dc8cdd..154de6060 100644
--- a/scoring/score_submissions.py
+++ b/scoring/score_submissions.py
@@ -11,6 +11,7 @@
   --compute_performance_profiles \
   --output_dir scoring_results_self_tuning \
   --self_tuning_ruleset
+
 """

 import operator
@@ -99,6 +100,8 @@ def get_summary_df(workload, workload_df, include_test_split=False):
       lambda x: x['accumulated_submission_time'][int(x[
           'index to target on val'])] if x['val target reached'] else np.inf,
       axis=1)
+
+  summary_df['step_time (s)'] = (
+      workload_df['accumulated_submission_time'] /
+      workload_df['global_step']).iloc[-1][-1]

   # test metrics
   if include_test_split:
diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py
index 116e70459..04affb852 100644
--- a/scoring/utils/slurm/make_job_config.py
+++ b/scoring/utils/slurm/make_job_config.py
@@ -13,11 +13,10 @@
 from absl import flags
 import jax

-SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py'
-EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
-TUNING_SEARCH_SPACE = None
-FRAMEWORK = 'pytorch'
-TUNING_RULESET = 'self'
+SUBMISSION_PATH = 'reference_algorithms/paper_baselines/adamw/jax/submission.py'
+TUNING_SEARCH_SPACE = 'reference_algorithms/paper_baselines/adamw/tuning_search_space.json'
+NUM_TUNING_TRIALS = 3  # For external tuning ruleset
+NUM_STUDIES = 3

 flags.DEFINE_string(
     'submission_path',
@@ -29,27 +28,35 @@
     'Path to tuning search space for submission module relative to algorithmic-efficiency dir.'
 )
 flags.DEFINE_string('experiment_dir',
-                    EXPERIMENT_DIR,
+                    '$HOME/experiments/',
                     'Path to experiment dir where logs will be saved.')
 flags.DEFINE_enum(
     'framework',
-    FRAMEWORK,
+    'jax',
     enum_values=['jax', 'pytorch'],
     help='Can be either pytorch or jax.')
 flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.')
 flags.DEFINE_enum(
     'tuning_ruleset',
-    TUNING_RULESET,
+    'external',
     enum_values=['external', 'self'],
     help='Which tuning ruleset to score this submission on. Can be external or self.'
 )
+flags.DEFINE_string(
+    'workloads',
+    None,
+    help='Comma-separated list of workloads to run.'
+)
+flags.DEFINE_integer(
+    'num_studies',
+    NUM_STUDIES,
+    help='Number of studies.'
+)
 FLAGS = flags.FLAGS

 MIN_INT = -2**(31)
 MAX_INT = 2**(31) - 1
-NUM_TUNING_TRIALS = 5  # For external tuning ruleset
-NUM_STUDIES = 3

 WORKLOADS = {
     "imagenet_resnet": {"dataset": "imagenet"},
@@ -64,7 +71,11 @@


 def main(_):
-  workloads = WORKLOADS.keys()
+  if not FLAGS.workloads:
+    workloads = WORKLOADS.keys()
+  else:
+    workloads = FLAGS.workloads.split(',')
+
   key = jax.random.key(FLAGS.seed)

   jobs = []
diff --git a/submission_runner.py b/submission_runner.py
index 468a04c7c..657238351 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -31,6 +31,10 @@
 from absl import logging
 import jax
 import tensorflow as tf
+
+# Partitionable threefry PRNG keeps random numbers consistent under sharding.
+jax.config.update('jax_default_prng_impl', 'threefry2x32')
+jax.config.update('jax_threefry_partitionable', True)
 import torch
 import torch.distributed as dist

@@ -162,6 +166,12 @@
     'Number of workers for ImageNet PyTorch evaluation data loaders.'
     'WARNING: Setting pytorch_eval_num_workers != 0, will result '
     'in incorrect evals currently, see issues/732.')
+flags.DEFINE_boolean('capture_jax_trace',
+                     False,
+                     'Captures jax profiler trace and writes to experiment directory.')
+flags.DEFINE_boolean('skip_evals',
+                     False,
+                     'Skip evals on train eval, validation and test splits.')
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()

@@ -209,7 +219,8 @@ def train_once(
     profiler: Profiler,
     max_global_steps: int = None,
     log_dir: Optional[str] = None,
-    save_checkpoints: Optional[bool] = True
+    save_checkpoints: Optional[bool] = True,
+    skip_evals: Optional[bool] = False,
 ) -> Tuple[spec.Timing, Dict[str, Any]]:
   _reset_cuda_mem()
   data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
@@ -385,7 +396,7 @@ def train_once(

     # Check if submission is eligible for an untimed eval.
     if ((train_step_end_time - train_state['last_eval_time']) >=
-        workload.eval_period_time_sec or train_state['training_complete']):
+        workload.eval_period_time_sec or train_state['training_complete']) and not skip_evals:

       # Prepare for evaluation (timed).
       if prepare_for_eval is not None:
@@ -547,7 +558,9 @@ def score_submission_on_workload(workload: spec.Workload,
                                  save_checkpoints: Optional[bool] = True,
                                  hparam_start_index: Optional[bool] = None,
                                  hparam_end_index: Optional[bool] = None,
-                                 rng_seed: Optional[int] = None):
+                                 rng_seed: Optional[int] = None,
+                                 capture_trace: Optional[bool] = False,
+                                 skip_evals: Optional[bool] = False):
   # Expand paths because '~' may not be recognized
   data_dir = os.path.expanduser(data_dir)
   if imagenet_v2_data_dir:
@@ -627,6 +640,9 @@ def score_submission_on_workload(workload: spec.Workload,
         tuning_search_space[hi] = hyperparameters

       with profiler.profile('Train'):
+        if capture_trace:
+          logging.info(f'Capturing and saving jax trace to {log_dir}')
+          jax.profiler.start_trace(f'{log_dir}/traces')
         timing, metrics = train_once(workload, workload_name,
                                      global_batch_size,
                                      global_eval_batch_size,
@@ -640,7 +656,10 @@ def score_submission_on_workload(workload: spec.Workload,
                                      profiler,
                                      max_global_steps,
                                      tuning_dir_name,
-                                     save_checkpoints=save_checkpoints,)
+                                     save_checkpoints=save_checkpoints,
+                                     skip_evals=skip_evals)
+        if capture_trace:
+          jax.profiler.stop_trace()
       all_timings[hi] = timing
       all_metrics[hi] = metrics
       logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}')
@@ -665,12 +684,17 @@ def score_submission_on_workload(workload: spec.Workload,
       logging.info(f'Creating directory at {log_dir}.')
       logger_utils.makedir(log_dir)
     with profiler.profile('Train'):
+      if capture_trace:
+        logging.info(f'Capturing and saving jax trace to {log_dir}')
+        jax.profiler.start_trace(f'{log_dir}/traces')
      score, _ = train_once(
          workload, workload_name, global_batch_size, global_eval_batch_size,
          data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params,
          data_selection, prepare_for_eval, None, rng_seed, rng, profiler,
          max_global_steps, log_dir, save_checkpoints=save_checkpoints)
+      if capture_trace:
+        jax.profiler.stop_trace()
   return score

@@ -746,7 +770,9 @@ def main(_):
       save_checkpoints=FLAGS.save_checkpoints,
       hparam_start_index=FLAGS.hparam_start_index,
       hparam_end_index=FLAGS.hparam_end_index,
-      rng_seed=FLAGS.rng_seed)
+      rng_seed=FLAGS.rng_seed,
+      capture_trace=FLAGS.capture_jax_trace,
+      skip_evals=FLAGS.skip_evals)

   logging.info(f'Final {FLAGS.workload} score: {score}')
   if FLAGS.profile:
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index c4ca514a8..f576d136b 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -225,7 +225,7 @@ def _build_input_queue(self, *args, **kwargs):
     del kwargs
     np.random.seed(42)
-    if framework == 'jax' or USE_PYTORCH_DDP:
+    if USE_PYTORCH_DDP:
       batch_shape = (n_gpus, global_batch_size // n_gpus)
     else:
       batch_shape = (global_batch_size,)
@@ -422,6 +422,10 @@ def _test_submission(workload_name,
     global_batch_size = FLAGS.global_batch_size
     if FLAGS.global_batch_size < 0:
       raise ValueError('Must set --global_batch_size.')
+  elif global_batch_size < n_gpus and FLAGS.framework == 'jax':
+    raise ValueError(
+        'Global batch size cannot be smaller than the number of GPUs when using JAX sharding.'
+    )
   workload = _make_one_batch_workload(workload_class,
                                       workload_name,
                                       framework,
diff --git a/tests/test_jax_sharding_invariance.py b/tests/test_jax_sharding_invariance.py
new file mode 100644
index 000000000..823ef0843
--- /dev/null
+++ b/tests/test_jax_sharding_invariance.py
@@ -0,0 +1,90 @@
+"""Tests for sharding consistency in JAX workloads.
+
+Specifically, this tests the model init functions and the input pipeline.
+""" +import copy +import os +import sys + +from absl import flags +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized + +from algoperf.profiler import PassThroughProfiler +import submission_runner +from algoperf.workloads.workloads import import_workload +from algoperf.workloads.workloads import BASE_WORKLOADS_DIR +from algoperf.workloads.workloads import WORKLOADS + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +# (see https://github.com/google/model_search/pull/8). +FLAGS(sys.argv) + +FRAMEWORK = 'jax' # Can extend to pytorch later + + +test_case = dict(testcase_name='test_ogbg', + workload='ogbg') + + +class SubmissionRunnerTest(parameterized.TestCase): + """Tests for reference submissions.""" + + + @parameterized.named_parameters(test_case) + def test_invariance(self, workload_name): + workload_name = 'ogbg' + dataset_dir = f'/data/{workload_name}' + workload_metadata = copy.deepcopy(WORKLOADS[workload_name]) + workload_metadata['workload_path'] = os.path.join(BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + FRAMEWORK, + 'workload.py') + workload = import_workload(workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}) + + rng = jax.random.PRNGKey(0) + initial_params, model_state = workload.init_model_fn(rng) + data_iter = workload._build_input_queue(rng, 'train', dataset_dir, 32) + batch = next(data_iter) + inputs = batch['inputs'] + + def forward_pass(params, + batch, + model_state, + rng,): + logits, _ = workload.model_fn(initial_params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + return logits + + forward_pass_jitted = jax.jit(forward_pass, + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_replicate_sharding(), + ), + out_shardings=jax_sharding_utils.get_batch_dim_sharding()) + + logits = forward_pass(initial_params, + batch, + model_state, + rng,) + + logits_jitted = forward_pass_jitted(initial_params, + batch, + model_state, + rng,) + + jax.debug.visualize_array_sharding(logits_jitted) + + equal = jnp.allclose(logits, logits_jitted, atol=1e-6) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file