From ae48ccdb9aaedc80bae6adcc4a483abdb7c191af Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 21 Nov 2024 11:56:05 -0500 Subject: [PATCH 01/68] Use jax.jit for sharding initial steps Apply it to the MNIST workload and the Nesterov optimizer. --- algoperf/checkpoint_utils.py | 3 - algoperf/data_utils.py | 5 +- algoperf/sharding_utils.py | 62 +++++++++++++++++++ .../workloads/mnist/mnist_jax/workload.py | 23 ++++--- .../nesterov/jax/submission.py | 62 ++++++++++++++----- 5 files changed, 123 insertions(+), 32 deletions(-) create mode 100644 algoperf/sharding_utils.py diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..26ee5ee59 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -193,10 +193,7 @@ def save_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - model_params = jax.device_get(jax_utils.unreplicate(model_params)) opt_state, _ = optimizer_state - opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) - model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( model_params, diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..bac9155b5 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -60,10 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - # Reshape (global_batch_size, ...) to - # (local_device_count, per_device_batch_size, ...). - # Assumes that `global_batch_size % local_device_count == 0`. - return x.reshape((local_device_count, -1, *x.shape[1:])) + return x return jax.tree.map(_prepare, batch) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py new file mode 100644 index 000000000..4950f243a --- /dev/null +++ b/algoperf/sharding_utils.py @@ -0,0 +1,62 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec + + +def get_mesh() -> jax.sharding.Mesh: + """Creates a mesh from all available GPUs. 
Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + +def get_replicated_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + +def get_naive_sharding_spec(mesh=None): + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) + + +def get_naive_sharding(x, mesh=None): + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) + + +def shard_params(params, mesh=None): + """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, get_naive_sharding(x)), params + ) + + +def get_sharding_tree(params, mesh=None): + """Returns a sharding tree for a parameter tree.""" + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + + +def get_empty_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + + +def disp_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print( + f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n" + ) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..11429bae6 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,9 +10,9 @@ import jax.numpy as jnp import optax -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.mnist.workload import BaseMnistWorkload +from algorithmic_efficiency import param_utils, sharding_utils +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): @@ -46,7 +46,7 @@ def init_model_fn( train=True)['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' @@ -101,10 +101,13 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + static_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -125,11 +128,11 @@ def _eval_model( (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} - metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( self, 
num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + total_metrics = {'accuracy': total_metrics['accuracy'].item() / num_examples, 'loss': total_metrics['loss'].item() / num_examples} + return total_metrics diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..81f23b171 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -6,10 +6,11 @@ from flax import jax_utils import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp import optax -from algoperf import spec +from algorithmic_efficiency import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +38,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,13 +88,13 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, +# @functools.partial( +# jax.pmap, +# axis_name='batch', +# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), +# static_broadcasted_argnums=(0, 1), +# donate_argnums=(2, 3, 4)) +def train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -124,9 +125,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # # Get correct global mean loss and grad. 
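+  # Note: under jax.jit with a batch-sharded input, summed_loss,
+  # n_valid_examples, and grad already cover the full global batch (XLA
+  # inserts the cross-device reduction), so the explicit lax.psum that the
+  # pmap version needed is no longer required here.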
loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -173,7 +172,40 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, + + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # per_device_rngs + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings + ) + outputs = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -188,8 +220,8 @@ def update_params( if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state From eb5cac7ba5854e1b88476d46abe6169860211144 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 21 Nov 2024 12:14:09 -0500 Subject: [PATCH 02/68] Use jax.jit for adamw --- .../paper_baselines/adamw/jax/submission.py | 97 ++++++++++++------- 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..d017da551 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -6,10 +6,11 @@ from flax import jax_utils import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp import optax -from algoperf import spec +from algoperf import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -50,24 +51,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -90,9 +85,8 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + + # Compute local loss and gradients loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -105,7 +99,7 @@ def _loss_fn(params): grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm @@ -139,23 +133,58 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Set up mesh and sharding + mesh = sharding_utils.get_mesh() + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Define input and output shardings + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings + ) + + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state From 82977da560637654c93dd7e511be85e2006dcfaa Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:53:50 -0500 Subject: [PATCH 03/68] Pass yapf checks --- algoperf/sharding_utils.py | 71 +++++++++---------- .../workloads/mnist/mnist_jax/workload.py | 16 +++-- algoperf/workloads/mnist/workload.py | 7 +- .../paper_baselines/adamw/jax/submission.py | 19 +++-- .../nesterov/jax/submission.py | 59 +++++++-------- 5 files changed, 86 insertions(+), 86 deletions(-) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 4950f243a..62a441bc9 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -5,58 +5,57 @@ def get_mesh() -> jax.sharding.Mesh: - """Creates a mesh from all available GPUs. Here, we simply create a one-dimensional mesh.""" - return jax.sharding.Mesh(jax.devices(), ("batch",)) + """Creates a mesh from all available GPUs. 
Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + def get_replicated_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + def get_naive_sharding_spec(mesh=None): - """Returns a sharding spec that shards data along the first axis.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec("batch")) + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) def get_naive_sharding(x, mesh=None): - """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" - if mesh is None: - mesh = get_mesh() - grid_size = mesh.shape["batch"] - if x.shape[0] % grid_size == 0: - return NamedSharding(mesh, PartitionSpec("batch")) - else: - return NamedSharding(mesh, PartitionSpec()) + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) def shard_params(params, mesh=None): - """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" - if mesh is None: - mesh = get_mesh() - return jax.tree_util.tree_map( - lambda x: jax.device_put(x, get_naive_sharding(x)), params - ) + """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, get_naive_sharding(x)), params) def get_sharding_tree(params, mesh=None): - """Returns a sharding tree for a parameter tree.""" - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + """Returns a sharding tree for a parameter tree.""" + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) def get_empty_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) def disp_shard_info(x: jax.Array): - """Displays shard info of a jax array.""" - for shard in x.addressable_shards: - print( - f"shard.device: {shard.device}, index: {shard.index}, replica_id:" - f" {shard.replica_id}.\n" - ) + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 11429bae6..b0a52d77f 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -102,11 +102,12 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - sharding_utils.get_replicated_sharding(), # model_state - 
sharding_utils.get_naive_sharding_spec(), # rng - ), + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), static_argnums=(0,)) def _eval_model( self, @@ -134,5 +135,8 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - total_metrics = {'accuracy': total_metrics['accuracy'].item() / num_examples, 'loss': total_metrics['loss'].item() / num_examples} + total_metrics = { + 'accuracy': total_metrics['accuracy'].item() / num_examples, + 'loss': total_metrics['loss'].item() / num_examples + } return total_metrics diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index f53aadd0b..c92dd141b 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -46,8 +46,7 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + 'targets': x['label'],}) is_train = split == 'train' if cache: @@ -214,8 +213,6 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index d017da551..73f41adbe 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -165,18 +165,17 @@ def update_params( static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) outputs = jitted_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 81f23b171..0c648a156 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -95,14 +95,14 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): # static_broadcasted_argnums=(0, 1), # donate_argnums=(2, 3, 4)) def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -182,20 +182,20 @@ def update_params( arg_shardings = ( # workload is static # opt_update_fn is static - replicated, # model_state - replicated, # optimizer_state - replicated, # current_param_container - sharded, # batch - sharded, # per_device_rngs + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # per_device_rngs replicated, # grad_clip - replicated # label_smoothing + replicated # label_smoothing ) out_shardings = ( - replicated, # new_optimizer_state - replicated, # updated_params - replicated, # new_model_state - replicated, # loss - replicated # grad_norm + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm ) # Jit with shardings jitted_train_step = jax.jit( @@ -203,17 +203,16 @@ def update_params( static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) outputs = jitted_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. @@ -271,6 +270,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 99545d4082569d8f98c9314836f8f349709405e4 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:58:03 -0500 Subject: [PATCH 04/68] CIFAR workload sharding --- .../cifar/cifar_jax/input_pipeline.py | 4 +- .../workloads/cifar/cifar_jax/workload.py | 153 ++++++++++++------ 2 files changed, 103 insertions(+), 54 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..18cb9ac5b 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,7 +8,6 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds @@ -171,5 +170,6 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + # FIXME(rka97): Figure out how to do prefetching+sharding. 
+ # it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..7dd883f1e 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax @@ -12,7 +11,7 @@ import optax import tensorflow_datasets as tfds -from algoperf import param_utils +from algoperf import param_utils, sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -28,15 +27,16 @@ def _build_cifar_dataset( data_dir: str, batch_size: int, cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: - ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) - train = split == 'train' + data_dir = data_dir + "/cifar10" + ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) + train = split == "train" assert self.num_train_examples + self.num_validation_examples == 50000 - if split in ['train', 'eval_train']: - split = f'train[:{self.num_train_examples}]' - elif split == 'validation': - split = f'train[{self.num_train_examples}:]' + if split in ["train", "eval_train"]: + split = f"train[:{self.num_train_examples}]" + elif split == "validation": + split = f"train[{self.num_train_examples}:]" ds = create_input_iter( split, ds_builder, @@ -48,7 +48,8 @@ def _build_cifar_dataset( self.padding_size, train=train, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( @@ -59,7 +60,8 @@ def _build_input_queue( global_batch_size: int, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches return self._build_cifar_dataset(data_rng, split, @@ -74,34 +76,45 @@ def sync_batch_stats( # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics # and we average them. 
+<<<<<<< variant A avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) +>>>>>>> variant B + avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x") + new_model_state = model_state.copy( + {"batch_stats": avg_fn(model_state["batch_stats"])}) +======= end return new_model_state def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate - model_cls = getattr(models, 'ResNet18') + model_cls = getattr(models, "ResNet18") model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, + variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) +<<<<<<< variant A model_state, params = pop(variables, 'params') +>>>>>>> variant B + model_state, params = variables.pop("params") +======= end self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + # model_state = jax_utils.replicate(model_state) + # params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: - return param_key == 'Dense_0' + return param_key == "Dense_0" def model_fn( self, @@ -111,26 +124,38 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, +<<<<<<< variant A use_running_average_bn: Optional[bool] = None +>>>>>>> variant B +======= end ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng - variables = {'params': params, **model_state} + variables = {"params": params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, +<<<<<<< variant A mutable=['batch_stats'], use_running_average_bn=use_running_average_bn) +>>>>>>> variant B + mutable=["batch_stats"], + ) +======= end return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, mutable=False, +<<<<<<< variant A use_running_average_bn=use_running_average_bn) +>>>>>>> variant B + ) +======= end return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in @@ -140,13 +165,15 @@ def loss_fn( label_batch: spec.Tensor, # Dense or one-hot labels. logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ + Return {'summed': scalar summed loss, + 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -159,51 +186,73 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + "summed": summed_loss, + "n_valid_examples": n_valid_examples, + "per_example": per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)['summed'] + summed_loss = self.loss_fn(labels, logits, weights)["summed"] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, - } - metrics = lax.psum(metrics, axis_name='batch') - return metrics + return jnp.array(summed_loss), jnp.array(accuracy) - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + ) + def _per_device_eval_model( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + weights = batch.get("weights") + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch["targets"], weights) + + losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) + metrics = { + "loss": + jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, + "accuracy": + (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies + ), + } + return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" +<<<<<<< variant A return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) +>>>>>>> variant B + return jax.tree_map(lambda x: x / num_examples, total_metrics) +======= end From 018711a77c43bbf2477d9bfbd8ec1971572db222 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 7 Jan 2025 21:18:44 +0000 Subject: [PATCH 05/68] librispeech_conformer now running Still need to test out (a) output losses, 
(b) speed, and (c) look into other librispeech. --- .../librispeech_jax/workload.py | 101 +++++++++++++----- .../nesterov/jax/submission.py | 5 +- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..b2f5a2903 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -7,6 +7,9 @@ import flax.linen as nn import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P + +from algoperf import sharding_utils import jax.numpy as jnp import numpy as np import optax @@ -21,7 +24,6 @@ LibriSpeechDataset from algoperf.workloads.librispeech_conformer.librispeech_jax import models - class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): def __init__(self, @@ -93,8 +95,16 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + + # Add sharding + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + model_state) + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -180,6 +190,7 @@ def _build_input_queue( 'targets': (targets.numpy(), target_paddings.numpy()), } + # Use data_utils.shard_and_maybe_pad_np to handle sharding padded_batch = data_utils.shard_and_maybe_pad_np( numpy_batch, padding_value=1.0) yield padded_batch @@ -305,11 +316,16 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) - def eval_step_pmapped( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + out_shardings=sharding_utils.get_naive_sharding_spec(), + static_argnums=(0,)) + def _eval_step( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -327,13 +343,39 @@ def eval_step_pmapped( loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) targets, target_paddings = batch['targets'] - return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') + # Convert metrics bundle to dictionary + metrics_dict = { + 'loss_per_example': loss['per_example'], + 'decoded': decoded, + 'decoded_paddings': decoded_paddings, + 'targets': targets, + 'target_paddings': target_paddings, + 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + } + return metrics_dict + + def eval_step( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): + """Evaluates the model and returns a metrics bundle.""" + metrics_dict = self._eval_step(params, batch, model_state, rng) + + # Convert dictionary back to metrics bundle + metrics = 
self.metrics_bundle.single_from_model_output( + loss_dict={ + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + }, + decoded=metrics_dict['decoded'], + decoded_paddings=metrics_dict['decoded_paddings'], + targets=metrics_dict['targets'], + target_paddings=metrics_dict['target_paddings']) + + return metrics def _eval_model_on_split(self, split: str, @@ -358,10 +400,10 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step(params, + eval_batch, + model_state, + rng) if metrics_report is None: metrics_report = computed_metrics @@ -373,15 +415,22 @@ def _eval_model_on_split(self, return computed_metrics + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # model_state + ), + out_shardings=sharding_utils.get_replicated_sharding(), + static_argnums=(0,) + ) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state + """Sync batch statistics across replicas.""" + # Replace pmean with direct mean across devices + new_batch_stats = jax.tree_map( + lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) + return model_state.copy({'batch_stats': new_batch_stats}) class LibriSpeechConformerAttentionTemperatureWorkload( diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0c648a156..544ba3c13 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -163,7 +163,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -186,7 +185,7 @@ def update_params( replicated, # optimizer_state replicated, # current_param_container sharded, # batch - sharded, # per_device_rngs + replicated, # rngs replicated, # grad_clip replicated # label_smoothing ) @@ -210,7 +209,7 @@ def update_params( optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs From fbeb5f1b1fa184c9fcb881058564f856bc2e89c7 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:33:39 +0000 Subject: [PATCH 06/68] fix formatting --- algoperf/sharding_utils.py | 4 +- .../librispeech_jax/models.py | 10 ++-- .../librispeech_jax/spectrum_augmenter.py | 4 +- .../librispeech_jax/workload.py | 60 ++++++++++--------- .../librispeech_jax/models.py | 10 ++-- .../workloads/mnist/mnist_jax/workload.py | 3 +- .../paper_baselines/adamw/jax/submission.py | 17 +++--- .../nesterov/jax/submission.py | 9 +-- submission_runner.py | 2 + 9 files changed, 61 insertions(+), 58 deletions(-) diff --git 
a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 62a441bc9..93a4dd53f 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -1,7 +1,9 @@ """Utilities for dealing with sharding in JAX.""" import jax -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec def get_mesh() -> jax.sharding.Mesh: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 593d463c3..2bb527a36 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,12 +442,10 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..c16740629 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights + < multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index b2f5a2903..d68974cd8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -7,15 +7,14 @@ import flax.linen as nn import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P - -from algoperf import sharding_utils import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P import numpy as np import optax import torch from algoperf import data_utils +from algoperf import sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics @@ -24,6 +23,7 @@ LibriSpeechDataset from algoperf.workloads.librispeech_conformer.librispeech_jax import models + class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): def __init__(self, @@ -99,10 +99,12 @@ def init_model_fn( # Add sharding mesh = sharding_utils.get_mesh() params = jax.tree_map( - lambda x: 
jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), params) model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), model_state) return params, model_state @@ -345,30 +347,35 @@ def _eval_step( targets, target_paddings = batch['targets'] # Convert metrics bundle to dictionary metrics_dict = { - 'loss_per_example': loss['per_example'], - 'decoded': decoded, - 'decoded_paddings': decoded_paddings, - 'targets': targets, - 'target_paddings': target_paddings, - 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + 'loss_per_example': + loss['per_example'], + 'decoded': + decoded, + 'decoded_paddings': + decoded_paddings, + 'targets': + targets, + 'target_paddings': + target_paddings, + 'n_valid_examples': + jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] } return metrics_dict - def eval_step( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState): + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): """Evaluates the model and returns a metrics bundle.""" metrics_dict = self._eval_step(params, batch, model_state, rng) # Convert dictionary back to metrics bundle metrics = self.metrics_bundle.single_from_model_output( loss_dict={ - 'summed': metrics_dict['loss_per_example'].sum(), - 'per_example': metrics_dict['loss_per_example'], - 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() }, decoded=metrics_dict['decoded'], decoded_paddings=metrics_dict['decoded_paddings'], @@ -400,10 +407,7 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step(params, - eval_batch, - model_state, - rng) + computed_metrics = self.eval_step(params, eval_batch, model_state, rng) if metrics_report is None: metrics_report = computed_metrics @@ -421,15 +425,13 @@ def _eval_model_on_split(self, sharding_utils.get_replicated_sharding(), # model_state ), out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,) - ) + static_argnums=(0,)) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: """Sync batch statistics across replicas.""" # Replace pmean with direct mean across devices - new_batch_stats = jax.tree_map( - lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) + new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) return model_state.copy({'batch_stats': new_batch_stats}) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..c6e7275be 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, 
jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,12 +273,10 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index b0a52d77f..b57dd72dd 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,8 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils, sharding_utils +from algorithmic_efficiency import param_utils +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 73f41adbe..156c7ab20 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -6,8 +6,9 @@ from flax import jax_utils import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax from algoperf import spec, sharding_utils @@ -124,7 +125,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -146,17 +146,17 @@ def update_params( replicated, # model_state replicated, # optimizer_state replicated, # current_param_container - sharded, # batch - sharded, # rng + sharded, # batch + replicated, # rng replicated, # grad_clip - replicated # label_smoothing + replicated # label_smoothing ) out_shardings = ( replicated, # new_optimizer_state replicated, # updated_params replicated, # new_model_state replicated, # loss - replicated # grad_norm + replicated # grad_norm ) # Jit with shardings @@ -167,16 +167,15 @@ def update_params( in_shardings=arg_shardings, out_shardings=out_shardings) - outputs = jitted_train_step(workload, + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 544ba3c13..5b332030b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -6,11 +6,13 @@ from flax import jax_utils import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax -from algorithmic_efficiency import spec, sharding_utils +from algorithmic_efficiency import sharding_utils +from algorithmic_efficiency import spec _GRAD_CLIP_EPS = 1e-6 @@ -203,7 +205,7 @@ def update_params( donate_argnums=(2, 3, 4), in_shardings=arg_shardings, out_shardings=out_shardings) - outputs = jitted_train_step(workload, + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -212,7 +214,6 @@ def update_params( rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/submission_runner.py b/submission_runner.py index a2521e77b..ce61bb581 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,8 @@ from absl import logging import jax import tensorflow as tf +jax.config.update('jax_default_prng_impl', 'threefry2x32') +jax.config.update('jax_threefry_partitionable', True) import torch import torch.distributed as dist From 6e4e7b0023c025e68bcbb05ef1e37de97886a8f8 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:34:01 +0000 Subject: [PATCH 07/68] shard default --- algoperf/data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index bac9155b5..6fede4430 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algoperf import spec +from algoperf import spec, sharding_utils def shard_and_maybe_pad_np( @@ -60,7 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return x + return jax.device_put(x, sharding_utils.get_naive_sharding_spec()) # x return jax.tree.map(_prepare, batch) From 4a2c02d3155e28622b24c014774741edd50b4736 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:34:08 +0000 Subject: [PATCH 08/68] start imagenet --- algoperf/data_utils.py | 4 +- .../imagenet_jax/input_pipeline.py | 2 +- .../imagenet_resnet/imagenet_jax/workload.py | 50 +++++++++++++------ 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 6fede4430..44eca8c26 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -50,7 +50,7 @@ def shard_and_maybe_pad_np( weights = batch.get('weights') # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights - + naive_sharding_spec = sharding_utils.get_naive_sharding_spec() def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. 
if not isinstance(x, np.ndarray): @@ -60,7 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return jax.device_put(x, sharding_utils.get_naive_sharding_spec()) # x + return jax.device_put(x, naive_sharding_spec) return jax.tree.map(_prepare, batch) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..b1be5ac1f 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -399,6 +399,6 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. - it = jax_utils.prefetch_to_device(it, 2) + # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..b60fd2753 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -19,6 +19,7 @@ import tensorflow_datasets as tfds from algoperf import param_utils +from algoperf import sharding_utils from algoperf import random_utils as prng from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 @@ -71,16 +72,20 @@ def _build_dataset( use_randaug=use_randaug) return ds + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # model_state + ), + out_shardings=sharding_utils.get_replicated_sharding(), + static_argnums=(0,)) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() # Create a shallow copy - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state + """Sync batch statistics across replicas.""" + new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) + return model_state.copy({'batch_stats': new_batch_stats}) + def init_model_fn( self, @@ -113,18 +118,30 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), + model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_naive_sharding_spec()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -218,7 +235,7 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') + # metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -252,11 +269,12 @@ def _eval_model_on_split(self, eval_rng = prng.fold_in(eval_rng, bi) step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) - # We already average these metrics across devices inside _compute_metrics. 
synced_metrics = self._eval_model(params, batch, model_state, step_eval_rngs) + # Sum up the synced metrics + synced_metrics = jax.tree_map(lambda x: jnp.sum(x, axis=0), synced_metrics) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 From 47beba15f9b6a2d6262ce228e5e4b8b82ae53c70 Mon Sep 17 00:00:00 2001 From: rka97 Date: Wed, 5 Feb 2025 14:34:13 +0000 Subject: [PATCH 09/68] remove bn sync in imagenet (jit handles it automatically) --- .../workloads/imagenet_resnet/imagenet_jax/workload.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index b60fd2753..d7285ab76 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -141,7 +141,7 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: sharding_utils.get_replicated_sharding(), # rng ), static_argnums=(0,), - out_shardings=sharding_utils.get_naive_sharding_spec()) + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -248,9 +248,6 @@ def _eval_model_on_split(self, data_dir: str, global_step: int = 0) -> Dict[str, float]: del global_step - if model_state is not None: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) data_rng, eval_rng = prng.split(rng, 2) # We already repeat the dataset indefinitely in tf.data. @@ -273,14 +270,12 @@ def _eval_model_on_split(self, batch, model_state, step_eval_rngs) - # Sum up the synced metrics - synced_metrics = jax.tree_map(lambda x: jnp.sum(x, axis=0), synced_metrics) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics From 3a18f19c87e0869eceb12894fa4bb7a636a276af Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 6 Feb 2025 18:08:16 +0000 Subject: [PATCH 10/68] ImageNet-ViT also works --- algoperf/sharding_utils.py | 5 +++++ .../imagenet_resnet/imagenet_jax/workload.py | 19 +------------------ .../imagenet_vit/imagenet_jax/workload.py | 7 ++++--- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 93a4dd53f..f158b5f9b 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -17,6 +17,11 @@ def get_replicated_sharding(mesh=None): mesh = get_mesh() return NamedSharding(mesh, PartitionSpec()) +def shard_replicated(x, mesh=None): + """Shards a tensor across all devices.""" + if mesh is None: + mesh = get_mesh() + return jax.tree_map(lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) def get_naive_sharding_spec(mesh=None): """Returns a sharding spec that shards data along the first axis.""" diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index d7285ab76..a2fc327a1 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -72,21 +72,6 @@ 
def _build_dataset( use_randaug=use_randaug) return ds - @functools.partial( - jax.jit, - in_shardings=( - sharding_utils.get_replicated_sharding(), # model_state - ), - out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,)) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync batch statistics across replicas.""" - new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) - return model_state.copy({'batch_stats': new_batch_stats}) - - def init_model_fn( self, rng: spec.RandomState, @@ -235,7 +220,6 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - # metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -264,12 +248,11 @@ def _eval_model_on_split(self, eval_metrics = {} for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) - step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) synced_metrics = self._eval_model(params, batch, model_state, - step_eval_rngs) + eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..22e714a37 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,5 +1,6 @@ """ImageNet workload implemented in Jax.""" +import functools from typing import Dict, Optional, Tuple from flax import jax_utils @@ -8,7 +9,7 @@ import jax import jax.numpy as jnp -from algoperf import param_utils +from algoperf import param_utils, sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload @@ -46,8 +47,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = sharding_utils.shard_replicated(params) + model_state = sharding_utils.shard_replicated(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: From bd0f5650988a037648944dca9f64cc4bd1918e42 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 20 Feb 2025 15:45:12 +0000 Subject: [PATCH 11/68] Start working on WMT. 
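The diff below moves the WMT eval step from pmap to jit with explicit
shardings: parameters stay replicated, the batch is split along its leading
axis, and self is kept static. A rough standalone sketch of that pattern,
using a toy model and placeholder shapes rather than the actual workload code:

import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("batch",))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P("batch"))

@functools.partial(
    jax.jit,
    in_shardings=(replicated, batch_sharded),  # params, batch
    out_shardings=replicated)
def toy_eval_step(params, batch):
  logits = batch["inputs"] @ params["w"]
  # Per-shard sums are combined into global sums automatically.
  return {"loss": jnp.sum((logits - batch["targets"])**2),
          "denominator": jnp.sum(batch["weights"])}

n = 2 * jax.device_count()
params = {"w": jnp.ones((4, 3))}
batch = {"inputs": jnp.ones((n, 4)),
         "targets": jnp.zeros((n, 3)),
         "weights": jnp.ones((n,))}
metrics = toy_eval_step(params, batch)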
OOM error --- algoperf/sharding_utils.py | 12 +++- algoperf/workloads/wmt/wmt_jax/workload.py | 67 ++++++++++++++++------ 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index f158b5f9b..4c330b767 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -35,12 +35,11 @@ def get_naive_sharding(x, mesh=None): if mesh is None: mesh = get_mesh() grid_size = mesh.shape["batch"] - if x.shape[0] % grid_size == 0: + if len(x.shape) > 0 and x.shape[0] % grid_size == 0: return NamedSharding(mesh, PartitionSpec("batch")) else: return NamedSharding(mesh, PartitionSpec()) - def shard_params(params, mesh=None): """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" if mesh is None: @@ -48,6 +47,15 @@ def shard_params(params, mesh=None): return jax.tree_util.tree_map( lambda x: jax.device_put(x, get_naive_sharding(x)), params) +def shard_naive(x, mesh=None): + return shard_params(x, mesh) + + + +def get_naive_sharding_tree(input_tree, mesh=None): + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), input_tree) def get_sharding_tree(params, mesh=None): """Returns a sharding tree for a parameter tree.""" diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..138da706d 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -13,8 +13,9 @@ import numpy as np import optax -from algoperf import param_utils +from algoperf import param_utils, sharding_utils from algoperf import spec +from algoperf import sharding_utils from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode from algoperf.workloads.wmt.wmt_jax import models @@ -69,10 +70,16 @@ def compute_weighted_cross_entropy( } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + ), + static_argnums=(0,), # self + ) + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] @@ -89,15 +96,23 @@ def eval_step_pmapped( 'accuracy': acc_sum, 'denominator': weight_sum, } +<<<<<<< variant A def eval_step(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) +>>>>>>> variant B +======= end @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_naive_sharding_spec(), # inputs + ), + static_argnums=(0,2,) + ) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: @@ -108,11 +123,10 @@ def initialize_cache(self, jax.random.PRNGKey(0), jnp.ones(inputs.shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) + print(initial_variables['cache']) + print(jax.tree_map(lambda x: x.shape, initial_variables['cache'])) return initial_variables['cache'] - # eos_id, max_decode_len are constant. 
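# The pmap decorator removed just below is replaced, inside
# translate_and_calculate_bleu, by a jax.jit call whose in_shardings include a
# sharding pytree built from the decode cache (get_naive_sharding_tree). A toy
# sketch of that idea, with a placeholder cache layout and a stand-in decode
# step rather than the real Transformer cache:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("batch",))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P("batch"))

n = 2 * jax.device_count()
cache = {"decoder_layer_0": {"cached_key": jnp.zeros((n, 8, 16)),
                             "cached_value": jnp.zeros((n, 8, 16))}}
# One sharding per cache leaf, mirroring the cache's tree structure.
cache_shardings = jax.tree.map(lambda _: batch_sharded, cache)

def toy_decode_step(params, cache):
  del params
  return jax.tree.map(lambda c: c + 1, cache)  # stand-in for one decode update

jitted_step = jax.jit(toy_decode_step,
                      in_shardings=(replicated, cache_shardings),
                      out_shardings=cache_shardings)
new_cache = jitted_step(jnp.ones(()), cache)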
- @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) def predict_step(self, inputs: spec.Tensor, params: spec.ParameterContainer, @@ -180,20 +194,34 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] + jitted_predit_step = None for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], + if jitted_predit_step is None: + jitted_predict_step = jax.jit( + self.predict_step, + in_shardings=( + sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_tree(cache), # cache + ), + static_argnums=(3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) + predicted = jitted_predict_step(pred_batch['inputs'], params, cache, decode.EOS_ID, max_predict_length) - predicted = _to_host(predicted) - targets = _to_host(pred_batch['targets']) + # predicted = _to_host(predicted) + # targets = _to_host(pred_batch['targets']) + targets = pred_batch['targets'] # Find actual batch size, ignoring the potential padding. weights = pred_batch.get('weights') if weights is not None: - weights = _to_host(weights) + # weights = _to_host(weights) actual_batch_size = int(weights.sum(0)[0].item()) else: actual_batch_size = len(predicted) @@ -213,7 +241,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - init_fake_batch_size = 2 + init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -235,15 +263,20 @@ def init_model_fn( eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) params_rng, dropout_rng = jax.random.split(rng) + inputs = jnp.ones(input_shape, jnp.float32) + targets = jnp.ones(target_shape, jnp.float32) + sharded_inputs = sharding_utils.shard_naive(inputs) + sharded_targets = sharding_utils.shard_naive(targets) + initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + sharded_inputs, sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + params = sharding_utils.shard_replicated(initial_params) + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' From 3044efb3a0050b592beaf068848834e624460459 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 20 Feb 2025 17:37:00 +0000 Subject: [PATCH 12/68] post-rebase, still on wmt --- algoperf/workloads/mnist/mnist_jax/workload.py | 8 ++++---- algoperf/workloads/wmt/wmt_jax/workload.py | 11 +---------- .../paper_baselines/nesterov/jax/submission.py | 4 ++-- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index b57dd72dd..38f6e332f 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,10 +10,10 @@ import jax.numpy as jnp import optax -from 
algorithmic_efficiency import param_utils -from algorithmic_efficiency import sharding_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload +from algoperf import param_utils +from algoperf import sharding_utils +from algoperf import spec +from algoperf.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 138da706d..cbcd70f31 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -96,15 +96,6 @@ def eval_step(self, 'accuracy': acc_sum, 'denominator': weight_sum, } -<<<<<<< variant A - - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: - replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) ->>>>>>> variant B -======= end @functools.partial( jax.jit, @@ -124,7 +115,7 @@ def initialize_cache(self, jnp.ones(inputs.shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) print(initial_variables['cache']) - print(jax.tree_map(lambda x: x.shape, initial_variables['cache'])) + print(jax.tree.map(lambda x: x.shape, initial_variables['cache'])) return initial_variables['cache'] def predict_step(self, diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 5b332030b..49e46109b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -11,8 +11,8 @@ from jax.sharding import PartitionSpec as P import optax -from algorithmic_efficiency import sharding_utils -from algorithmic_efficiency import spec +from algoperf import sharding_utils +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 From e301c492d271aef132120bd9dbdc9ad606dbae8c Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 20 Feb 2025 18:33:32 +0000 Subject: [PATCH 13/68] cache sharding fix --- algoperf/workloads/wmt/wmt_jax/workload.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cbcd70f31..b162b9e4c 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -110,12 +110,12 @@ def initialize_cache(self, """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] + dummy_inputs = sharding_utils.shard_naive(jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = sharding_utils.shard_naive(jnp.ones(target_shape, jnp.float32)) initial_variables = models.Transformer(config).init( jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) - print(initial_variables['cache']) - print(jax.tree.map(lambda x: x.shape, initial_variables['cache'])) + dummy_inputs, + dummy_targets) return initial_variables['cache'] def predict_step(self, @@ -185,11 +185,12 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] - jitted_predit_step = None + jitted_predict_step = None for _ in 
range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - if jitted_predit_step is None: + cache = sharding_utils.shard_naive(cache) + if jitted_predict_step is None: jitted_predict_step = jax.jit( self.predict_step, in_shardings=( From 4fcf98452d5ec77be4e076515d62b27e5c6d8386 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 21 Feb 2025 03:42:39 +0000 Subject: [PATCH 14/68] target_setting_algorithms sharding, compilation caching compilation caching really speeds things up when doing repeated runs. --- .../target_setting_algorithms/jax_adamw.py | 2 +- .../target_setting_algorithms/jax_momentum.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 2 +- .../target_setting_algorithms/jax_nesterov.py | 2 +- .../jax_submission_base.py | 71 +++++++++++++------ submission_runner.py | 6 ++ 6 files changed, 58 insertions(+), 27 deletions(-) diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..b99d1fd94 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -41,4 +41,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..9da67e8f9 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=False) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -168,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..533e23f2c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..88718ab67 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,26 +7,20 @@ import jax.numpy as jnp import 
optax -from algoperf import spec +from algoperf import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -49,9 +43,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Compute mean loss and grad loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -98,9 +90,42 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = sharding_utils.get_replicated_sharding(mesh) # No partitioning + sharded = sharding_utils.get_naive_sharding_spec(mesh) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, grad_clip, + current_param_container, batch, rng, grad_clip, label_smoothing) # Log loss, grad_norm. 
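# A quick standalone check of why the explicit lax.psum above is no longer
# needed: under jit, a reduction over a batch-sharded array already yields the
# global value (assuming the one-dimensional "batch" mesh from sharding_utils).
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("batch",))
n = 4 * jax.device_count()
x = jax.device_put(jnp.arange(float(n)), NamedSharding(mesh, P("batch")))
total = jax.jit(jnp.sum)(x)  # same result as jnp.sum on a single device
print(total, x.sharding)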
@@ -108,8 +133,8 @@ def update_params( workload.metrics_logger is not None): workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/submission_runner.py b/submission_runner.py index ce61bb581..43d707519 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,8 +31,14 @@ from absl import logging import jax import tensorflow as tf +# New PRNG implementation for correct sharding jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) +# JAX compilation caching +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +# jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") import torch import torch.distributed as dist From d147e3914f2495136a6c7b5c020328831afc2589 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 21 Feb 2025 04:16:22 +0000 Subject: [PATCH 15/68] Update tests to correct batch size --- .github/workflows/CI.yml | 20 ++++++++++---------- tests/reference_algorithm_tests.py | 4 +++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64ef0302e..ccd99e68d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json wmt_pytorch: runs-on: ubuntu-latest steps: @@ -54,7 +54,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json imagenet_jax: runs-on: ubuntu-latest steps: @@ -71,8 +71,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json imagenet_pytorch: runs-on: ubuntu-latest steps: @@ -89,8 +89,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json # uncomment when https://github.com/mlcommons/algorithmic-efficiency/issues/339 is resolved. criteo_jax: runs-on: ubuntu-latest @@ -142,8 +142,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json speech_pytorch: runs-on: ubuntu-latest steps: @@ -160,8 +160,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ogbg: runs-on: ubuntu-latest steps: diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..db8928309 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -225,7 +225,7 @@ def _build_input_queue(self, *args, **kwargs): del kwargs np.random.seed(42) - if framework == 'jax' or USE_PYTORCH_DDP: + if USE_PYTORCH_DDP: batch_shape = (n_gpus, global_batch_size // n_gpus) else: batch_shape = (global_batch_size,) @@ -422,6 +422,8 @@ def _test_submission(workload_name, global_batch_size = FLAGS.global_batch_size if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') + elif global_batch_size < n_gpus and FLAGS.framework == 'jax': + raise ValueError('Global batch size cannot be smaller than the number of GPUs when using JAX sharding.') workload = _make_one_batch_workload(workload_class, 
workload_name, framework, From a2b61bed56f228cd152f0e2bd6de0318366d472a Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 21 Feb 2025 04:48:31 +0000 Subject: [PATCH 16/68] yapf and isort checks.. --- algoperf/checkpoint_utils.py | 1 - algoperf/data_utils.py | 4 +- algoperf/profiler.py | 4 +- algoperf/sharding_utils.py | 24 +++++---- .../workloads/cifar/cifar_jax/workload.py | 31 +---------- .../fastmri/fastmri_pytorch/workload.py | 4 +- .../imagenet_jax/custom_tf_addons.py | 46 ++++++++-------- .../imagenet_jax/randaugment.py | 8 ++- .../imagenet_resnet/imagenet_jax/workload.py | 10 ++-- .../imagenet_pytorch/workload.py | 4 +- .../imagenet_vit/imagenet_jax/workload.py | 5 +- .../librispeech_jax/workload.py | 9 ++-- .../librispeech_pytorch/workload.py | 9 ++-- .../workloads/mnist/mnist_jax/workload.py | 1 - algoperf/workloads/wmt/bleu.py | 3 +- algoperf/workloads/wmt/wmt_jax/workload.py | 54 ++++++++++--------- algoperf/workloads/wmt/wmt_pytorch/models.py | 4 +- .../external_tuning/jax_nadamw_full_budget.py | 10 ++-- .../jax_nadamw_target_setting.py | 10 ++-- .../self_tuning/jax_nadamw_full_budget.py | 10 ++-- .../self_tuning/jax_nadamw_target_setting.py | 10 ++-- .../paper_baselines/adamw/jax/submission.py | 3 +- .../paper_baselines/nadamw/jax/submission.py | 10 ++-- .../paper_baselines/sam/jax/submission.py | 8 +-- .../shampoo/jax/distributed_shampoo.py | 28 ++++------ .../target_setting_algorithms/jax_nadamw.py | 10 ++-- .../jax_submission_base.py | 10 ++-- submission_runner.py | 6 +-- tests/modeldiffs/wmt/compare.py | 6 +-- .../modeldiffs/wmt_attention_temp/compare.py | 6 +-- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +-- tests/modeldiffs/wmt_post_ln/compare.py | 6 +-- tests/reference_algorithm_tests.py | 4 +- 33 files changed, 172 insertions(+), 192 deletions(-) diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 26ee5ee59..75d8d59ea 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -11,7 +11,6 @@ from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax import numpy as np from tensorflow.io import gfile # pytype: disable=import-error import torch diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 44eca8c26..abd3f51b3 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,8 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algoperf import spec, sharding_utils +from algoperf import sharding_utils +from algoperf import spec def shard_and_maybe_pad_np( @@ -51,6 +52,7 @@ def shard_and_maybe_pad_np( # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights naive_sharding_spec = sharding_utils.get_naive_sharding_spec() + def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. 
if not isinstance(x, np.ndarray): diff --git a/algoperf/profiler.py b/algoperf/profiler.py index fa2a1bee2..d73efd964 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) + for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py index 4c330b767..fc6b38d4a 100644 --- a/algoperf/sharding_utils.py +++ b/algoperf/sharding_utils.py @@ -1,13 +1,13 @@ """Utilities for dealing with sharding in JAX.""" import jax -from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec def get_mesh() -> jax.sharding.Mesh: - """Creates a mesh from all available GPUs. Here, we simply create a one-dimensional mesh.""" + """Creates a mesh from all available GPUs. + Here, we simply create a one-dimensional mesh.""" return jax.sharding.Mesh(jax.devices(), ("batch",)) @@ -17,11 +17,14 @@ def get_replicated_sharding(mesh=None): mesh = get_mesh() return NamedSharding(mesh, PartitionSpec()) + def shard_replicated(x, mesh=None): """Shards a tensor across all devices.""" if mesh is None: mesh = get_mesh() - return jax.tree_map(lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) + return jax.tree.map( + lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) + def get_naive_sharding_spec(mesh=None): """Returns a sharding spec that shards data along the first axis.""" @@ -40,26 +43,29 @@ def get_naive_sharding(x, mesh=None): else: return NamedSharding(mesh, PartitionSpec()) + def shard_params(params, mesh=None): - """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + """Shards a parameter tree across all devices + with naive sharding (see get_naive_sharding).""" if mesh is None: mesh = get_mesh() - return jax.tree_util.tree_map( - lambda x: jax.device_put(x, get_naive_sharding(x)), params) + return jax.tree.map(lambda x: jax.device_put(x, get_naive_sharding(x)), + params) + def shard_naive(x, mesh=None): return shard_params(x, mesh) - def get_naive_sharding_tree(input_tree, mesh=None): if mesh is None: mesh = get_mesh() - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), input_tree) + return jax.tree.map(lambda x: get_naive_sharding(x, mesh), input_tree) + def get_sharding_tree(params, mesh=None): """Returns a sharding tree for a parameter tree.""" - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + return jax.tree.map(lambda x: get_naive_sharding(x, mesh), params) def get_empty_sharding(mesh=None): diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 7dd883f1e..f827fac87 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -11,7 +11,8 @@ import optax import tensorflow_datasets as tfds -from algoperf import param_utils, sharding_utils +from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -76,15 +77,9 @@ def sync_batch_stats( # An axis_name is passed to pmap which can 
then be used by pmean. # In this case each device has its own version of the batch statistics # and we average them. -<<<<<<< variant A avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) ->>>>>>> variant B - avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - new_model_state = model_state.copy( - {"batch_stats": avg_fn(model_state["batch_stats"])}) -======= end return new_model_state def init_model_fn( @@ -102,15 +97,9 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) -<<<<<<< variant A model_state, params = pop(variables, 'params') ->>>>>>> variant B - model_state, params = variables.pop("params") -======= end self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - # model_state = jax_utils.replicate(model_state) - # params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -124,10 +113,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, -<<<<<<< variant A use_running_average_bn: Optional[bool] = None ->>>>>>> variant B -======= end ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng @@ -137,13 +123,8 @@ def model_fn( variables, augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, -<<<<<<< variant A mutable=['batch_stats'], use_running_average_bn=use_running_average_bn) ->>>>>>> variant B - mutable=["batch_stats"], - ) -======= end return logits, new_model_state else: logits = self._model.apply( @@ -151,11 +132,7 @@ def model_fn( augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, mutable=False, -<<<<<<< variant A use_running_average_bn=use_running_average_bn) ->>>>>>> variant B - ) -======= end return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in @@ -251,8 +228,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" -<<<<<<< variant A - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) ->>>>>>> variant B return jax.tree_map(lambda x: x / num_examples, total_metrics) -======= end diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 58943de2f..216a033d4 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -250,9 +250,7 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 3d6939218..c9bf154bb 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -20,27 +20,31 @@ tf.dtypes.float64, } -Number = Union[float, - int, - 
np.float16, - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64,] - -TensorLike = Union[List[Union[Number, list]], - tuple, - Number, - np.ndarray, - tf.Tensor, - tf.SparseTensor, - tf.Variable,] +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] + +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, +] def get_ndims(image): diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index c68e2de33..2d7e873c2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -316,8 +316,7 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, + tf.equal(step, 0), lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -552,7 +551,6 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), lambda: image) return image diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index a2fc327a1..6a7cf03be 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -19,8 +19,8 @@ import tensorflow_datasets as tfds from algoperf import param_utils -from algoperf import sharding_utils from algoperf import random_utils as prng +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -249,17 +249,13 @@ def _eval_model_on_split(self, for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) batch = next(self._eval_iters[split]) - synced_metrics = self._eval_model(params, - batch, - model_state, - eval_rng) + synced_metrics = self._eval_model(params, batch, model_state, eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: x / num_examples, - eval_metrics) + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ed29271f3..bd07edee1 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -307,9 +307,7 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + 
batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 22e714a37..1a2ce6342 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,15 +1,14 @@ """ImageNet workload implemented in Jax.""" -import functools from typing import Dict, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax import jax.numpy as jnp -from algoperf import param_utils, sharding_utils +from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index d68974cd8..69e4351d3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,20 +2,17 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec as P import numpy as np import optax import torch from algoperf import data_utils -from algoperf import sharding_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics from algoperf.workloads.librispeech_conformer import workload @@ -371,7 +368,7 @@ def eval_step(self, metrics_dict = self._eval_step(params, batch, model_state, rng) # Convert dictionary back to metrics bundle - metrics = self.metrics_bundle.single_from_model_output( + metrics_bundle = self.metrics_bundle.single_from_model_output( loss_dict={ 'summed': metrics_dict['loss_per_example'].sum(), 'per_example': metrics_dict['loss_per_example'], @@ -382,7 +379,7 @@ def eval_step(self, targets=metrics_dict['targets'], target_paddings=metrics_dict['target_paddings']) - return metrics + return metrics_bundle def _eval_model_on_split(self, split: str, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 5ed37957e..1dae389d7 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -259,8 +259,9 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + fin_result.shape[1], + device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( + -1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -328,9 +329,7 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for 
metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 38f6e332f..15fc6ff89 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -6,7 +6,6 @@ from flax import jax_utils from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..5e175320a 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -283,8 +283,7 @@ def ref_stats(output, refs): closest_diff = diff closest_len = reflen elif diff == closest_diff: - if reflen < closest_len: - closest_len = reflen + closest_len = min(reflen, closest_len) ngrams_ref = extract_ngrams(ref) for ngram in ngrams_ref: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index b162b9e4c..dce08a677 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -13,9 +13,9 @@ import numpy as np import optax -from algoperf import param_utils, sharding_utils -from algoperf import spec +from algoperf import param_utils from algoperf import sharding_utils +from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode from algoperf.workloads.wmt.wmt_jax import models @@ -72,10 +72,10 @@ def compute_weighted_cross_entropy( @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch ), - static_argnums=(0,), # self + static_argnums=(0,), # self ) def eval_step(self, params: spec.ParameterContainer, @@ -100,22 +100,24 @@ def eval_step(self, @functools.partial( jax.jit, in_shardings=( - sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_naive_sharding_spec(), # inputs ), - static_argnums=(0,2,) - ) + static_argnums=( + 0, + 2, + )) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - dummy_inputs = sharding_utils.shard_naive(jnp.ones(inputs.shape, jnp.float32)) - dummy_targets = sharding_utils.shard_naive(jnp.ones(target_shape, jnp.float32)) + dummy_inputs = sharding_utils.shard_naive( + jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = sharding_utils.shard_naive( + jnp.ones(target_shape, jnp.float32)) initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - dummy_inputs, - dummy_targets) + jax.random.PRNGKey(0), dummy_inputs, dummy_targets) return initial_variables['cache'] def predict_step(self, @@ -194,19 +196,20 @@ def translate_and_calculate_bleu(self, jitted_predict_step = jax.jit( self.predict_step, in_shardings=( - sharding_utils.get_naive_sharding_spec(), # inputs - sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_tree(cache), # cache + sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_tree(cache), # cache ), - static_argnums=(3, # eos_id - 4, # max_decode_len, - 
5, # beam_size - )) + static_argnums=( + 3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) predicted = jitted_predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) + params, + cache, + decode.EOS_ID, + max_predict_length) # predicted = _to_host(predicted) # targets = _to_host(pred_batch['targets']) targets = pred_batch['targets'] @@ -262,7 +265,8 @@ def init_model_fn( initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - sharded_inputs, sharded_targets) + sharded_inputs, + sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..089f1bfbb 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) + >= cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index c451a18ac..ade3e9f63 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index b8ac10f33..d3a27411c 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return 
optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 78c3b5b3e..9bc014ed1 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index ffe854a0e..a6781612c 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 156c7ab20..d4dd0539b 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -11,7 +11,8 @@ from jax.sharding import PartitionSpec as P import optax -from algoperf import spec, sharding_utils +from algoperf import sharding_utils +from algoperf import spec 
_GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index c451a18ac..ade3e9f63 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index b76589705..ce2d1feef 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,8 +67,9 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, + params, + updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -80,8 +81,7 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index a5c2732ac..4f670a85b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. 
- ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,17 +1809,13 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1923,8 +1919,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors + >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2442,9 +2438,7 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2453,9 +2447,7 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 1ba56bbda..fbb7a612c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,8 +108,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -124,8 +125,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py 
b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 88718ab67..1cfa5deca 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,7 +7,8 @@ import jax.numpy as jnp import optax -from algoperf import spec, sharding_utils +from algoperf import sharding_utils +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -93,7 +94,8 @@ def update_params( mesh = sharding_utils.get_mesh() # Create shardings for each argument replicated = sharding_utils.get_replicated_sharding(mesh) # No partitioning - sharded = sharding_utils.get_naive_sharding_spec(mesh) # Partition along batch dimension + sharded = sharding_utils.get_naive_sharding_spec( + mesh) # Partition along batch dimension # Create the sharding rules for each argument arg_shardings = ( @@ -114,7 +116,7 @@ def update_params( replicated, # loss replicated # grad_norm ) - + # Jit with shardings jitted_train_step = jax.jit( train_step, @@ -122,7 +124,7 @@ def update_params( donate_argnums=(2, 3, 4), in_shardings=arg_shardings, out_shardings=out_shardings) - + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, rng, grad_clip, diff --git a/submission_runner.py b/submission_runner.py index 43d707519..eace6054d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,7 @@ from absl import logging import jax import tensorflow as tf + # New PRNG implementation for correct sharding jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) @@ -38,7 +39,6 @@ jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) -# jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") import torch import torch.distributed as dist @@ -392,8 +392,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
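The `update_params` changes above (and the matching ones in the paper-baseline submissions) all follow the same recipe: build a one-dimensional device mesh, mark parameters and optimizer state as replicated, mark the batch as partitioned along its leading axis, and let `jax.jit` place the computation instead of `pmap`. The sketch below shows that recipe in isolation. The toy model, loss, and optimizer are placeholders, not the benchmark's; the benchmark additionally passes the workload and the optimizer update function as static arguments, which are folded into a closure here to keep the sketch short.

```python
import jax
import jax.numpy as jnp
import optax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())      # params / optimizer state: same copy on every device
sharded = NamedSharding(mesh, P('batch'))  # batch: split along the leading axis

tx = optax.sgd(0.1, momentum=0.9)


def train_step(optimizer_state, params, batch):
  def loss_fn(p):
    preds = batch['inputs'] @ p['w']
    return jnp.mean((preds - batch['targets']) ** 2)
  loss, grads = jax.value_and_grad(loss_fn)(params)
  updates, new_optimizer_state = tx.update(grads, optimizer_state, params)
  return new_optimizer_state, optax.apply_updates(params, updates), loss


jitted_train_step = jax.jit(
    train_step,
    donate_argnums=(0, 1),  # reuse the buffers of optimizer_state and params
    in_shardings=(replicated, replicated, sharded),
    out_shardings=(replicated, replicated, replicated))

params = {'w': jnp.ones((4, 1))}
optimizer_state = tx.init(params)
# The global batch size should be divisible by the number of devices.
batch = {'inputs': jnp.ones((8, 4)), 'targets': jnp.zeros((8, 1))}
batch = jax.device_put(batch, sharded)
optimizer_state, params, loss = jitted_train_step(optimizer_state, params, batch)
```

With this layout the gradient averaging over the sharded batch happens inside the jitted computation, which is what the removal of the explicit `psum`/`unreplicate` calls in these diffs relies on.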
if prepare_for_eval is not None: diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 64401ef7f..656d07fb5 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -75,9 +75,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 01dc2895c..d30474f00 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 77e71c826..4b0a0b218 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 909fcd672..818eec672 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index db8928309..f576d136b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -423,7 +423,9 @@ def _test_submission(workload_name, if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') elif global_batch_size < n_gpus and FLAGS.framework == 'jax': - raise ValueError('Global batch size cannot be smaller than the number of GPUs when using JAX sharding.') + raise ValueError( + 'Global batch size cannot be smaller than the number of GPUs when using JAX sharding.' 
+ ) workload = _make_one_batch_workload(workload_class, workload_name, framework, From a80f4ec68a5affbe1386adbe649dd5dcde8310a4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Mar 2025 19:43:02 +0000 Subject: [PATCH 17/68] switch fastmri from pmap to jit --- .../workloads/fastmri/fastmri_jax/workload.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..cbe961a09 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -10,6 +10,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import sharding_utils import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim @@ -39,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_utils.replicate(params) + params = sharding_utils.shard_replicated(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -94,10 +95,15 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding() + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding() + ) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], @@ -126,7 +132,6 @@ def _eval_model(self, 'ssim': ssim_sum, 'loss': summed_loss, } - metrics = jax.lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -154,13 +159,12 @@ def _eval_model_on_split(self, num_batches=num_batches) total_metrics = {'ssim': 0., 'loss': 0.} - eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. 
- synced_metrics = self._eval_model(params, batch, eval_rngs) + synced_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} From c39ca51d5b00f7752538a3d3d5e35d5870f3462e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Mar 2025 20:48:09 +0000 Subject: [PATCH 18/68] migrate criteo workload --- .../criteo1tb/criteo1tb_jax/workload.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..2f41eb8c6 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf import sharding_utils from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -105,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return sharding_utils.shard_replicated(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -129,11 +130,14 @@ def model_fn( return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding() + ) + def _eval_batch_jitted(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( @@ -157,7 +161,7 @@ def _eval_batch(self, # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. 
return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): From 06377d981b207dbcf8722b8dc36dd4cd474267e6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Mar 2025 20:48:43 +0000 Subject: [PATCH 19/68] update utils function used for sharding conformer --- .../librispeech_conformer/librispeech_jax/workload.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 69e4351d3..cd04d4746 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -94,15 +94,8 @@ def init_model_fn( self._param_types = param_utils.jax_param_types(self._param_shapes) # Add sharding - mesh = sharding_utils.get_mesh() - params = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), - params) - model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), - model_state) + params = sharding_utils.shard_replicated(params) + model_state = sharding_utils.shard_replicated(model_state) return params, model_state From 9cbe7d92e136d22977dce18c61c5712d9b106ba4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Mar 2025 07:21:24 +0000 Subject: [PATCH 20/68] update conformer and deepspeech --- .../librispeech_jax/workload.py | 21 +---------------- .../librispeech_jax/models.py | 4 ---- .../librispeech_jax/workload.py | 5 ++-- .../paper_baselines/adamw/jax/submission.py | 23 ++++++++----------- 4 files changed, 14 insertions(+), 39 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index cd04d4746..8e3f3d975 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -333,7 +333,6 @@ def _eval_step( decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) - targets, target_paddings = batch['targets'] # Convert metrics bundle to dictionary metrics_dict = { @@ -385,9 +384,6 @@ def _eval_model_on_split(self, global_step: int = 0) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step - if model_state is not None and len(model_state) > 0: - # Sync batch statistics across replicas before evaluating. 
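These `init_model_fn` changes drop `jax_utils.replicate`, which stacked a leading device axis for `pmap`, in favour of `sharding_utils.shard_replicated`, so parameters and batch-norm state are stored once per device with no extra leading dimension. That helper's definition is not part of this excerpt; the sketch below is one plausible implementation, assuming it simply `device_put`s every leaf with a fully replicated `NamedSharding`, consistent with the other helpers used in these diffs.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_replicated(tree, mesh=None):
  """Assumed behaviour: place every leaf on all devices, unpartitioned."""
  if mesh is None:
    mesh = Mesh(jax.devices(), ('batch',))
  replicated = NamedSharding(mesh, PartitionSpec())
  return jax.tree_util.tree_map(lambda x: jax.device_put(x, replicated), tree)


# Leaves keep their original shapes, e.g. (16, 4) rather than the
# (n_devices, 16, 4) that jax_utils.replicate would have produced for pmap.
params = {'dense': {'kernel': jnp.ones((16, 4)), 'bias': jnp.zeros((4,))}}
params = shard_replicated(params)
print(jax.tree_util.tree_map(lambda x: x.sharding, params))
```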
- model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: @@ -409,21 +405,6 @@ def _eval_model_on_split(self, return computed_metrics - @functools.partial( - jax.jit, - in_shardings=( - sharding_utils.get_replicated_sharding(), # model_state - ), - out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,)) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync batch statistics across replicas.""" - # Replace pmean with direct mean across devices - new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) - return model_state.copy({'batch_stats': new_batch_stats}) - class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): @@ -465,7 +446,7 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: return 0.094114 - + @property def test_target_value(self) -> float: return 0.056629 diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index c6e7275be..07481272f 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -308,16 +308,12 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v variance = (inputs - mean) * (inputs - mean) * mask sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..392de7b89 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -8,6 +8,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import sharding_utils from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -51,8 +52,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = sharding_utils.shard_replicated(model_state) + params = sharding_utils.shard_replicated(params) return params, model_state def model_fn( diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index d4dd0539b..5baa046f5 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -55,16 +55,15 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return optimizer_state, opt_update_fn - def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - 
current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -159,15 +158,13 @@ def update_params( replicated, # loss replicated # grad_norm ) - - # Jit with shardings jitted_train_step = jax.jit( train_step, static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings) - + out_shardings=out_shardings + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, From c6ecd676da7ebdc63bb43eddff7b08b36b6e2deb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Mar 2025 17:16:31 +0000 Subject: [PATCH 21/68] debugging --- .../librispeech_jax/models.py | 339 +++++++++++++++++- 1 file changed, 335 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 07481272f..5d6be2f5d 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -8,8 +8,10 @@ # webpage : https://bastings.github.io/ """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +import functools +import flax from flax import linen as nn from flax import struct import jax @@ -418,6 +420,321 @@ def unpack_weights( self.bidirectional, ) +### Swap in regular LSTM layer for debuggin +@jax.vmap +def flip_sequences(inputs: Array, lengths: Array) -> Array: + """Flips a sequence of inputs along the time dimension. + + This function can be used to prepare inputs for the reverse direction of a + bidirectional LSTM. It solves the issue that, when naively flipping multiple + padded sequences stored in a matrix, the first elements would be padding + values for those sequences that were padded. This function keeps the padding + at the end, while flipping the rest of the elements. + + Example: + ```python + inputs = [[1, 0, 0], + [2, 3, 0] + [4, 5, 6]] + lengths = [1, 2, 3] + flip_sequences(inputs, lengths) = [[1, 0, 0], + [3, 2, 0], + [6, 5, 4]] + ``` + + Args: + inputs: An array of input IDs [batch_size, seq_length]. + lengths: The length of each sequence [batch_size]. + + Returns: + An ndarray with the flipped inputs. + """ + # Compute the indices to put the inputs in flipped order as per above example. + max_length = inputs.shape[0] + idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length + return inputs[idxs] + +class GenericRNNSequenceEncoder(nn.Module): + """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`. + + The sequence can be encoded left-to-right (default) or right-to-left (by + calling the module with reverse=True). Regardless of encoding direction, + outputs[i, j, ...] is the representation of inputs[i, j, ...]. + + Attributes: + hidden_size: The hidden size of the RNN cell. + cell_type: The RNN cell module to use, for example, `nn.LSTMCell`. + cell_kwargs: Optional keyword arguments for the recurrent cell. + recurrent_dropout_rate: The dropout to apply across time steps. If this is + greater than zero, you must use an RNN cell that implements + `RecurrentDropoutCell` such as RecurrentDropoutOptimizedLSTMCell. 
+ """ + hidden_size: int + cell_type: Type[nn.RNNCellBase] + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + recurrent_dropout_rate: float = 0.0 + + def setup(self): + self.cell = self.cell_type(features=self.hidden_size, **self.cell_kwargs) + + @functools.partial( # Repeatedly calls the below method to encode the inputs. + nn.transforms.scan, + variable_broadcast='params', + in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), + out_axes=1, + split_rngs={'params': False}) + def unroll_cell(self, cell_state: StateType, inputs: Array, + recurrent_dropout_mask: Optional[Array], deterministic: bool): + """Unrolls a recurrent cell over an input sequence. + + Args: + cell_state: The initial cell state, shape: [batch_size, + hidden_size] (or an n-tuple thereof). + inputs: The input sequence. [batch_size, seq_len, input_dim]. + recurrent_dropout_mask: An optional recurrent dropout mask to apply in + between time steps. [batch_size, hidden_size]. + deterministic: Disables recurrent dropout when set to True. + + Returns: + The cell state after processing the complete sequence (including padding), + and a tuple with all intermediate cell states and cell outputs. + """ + # We do not directly scan the cell itself, since it only returns the output. + # This returns both the state and the output, so we can slice out the + # correct final states later. + new_cell_state, output = self.cell(cell_state, inputs) + return new_cell_state, (new_cell_state, output) + + def __call__(self, + inputs: Array, + lengths: Array, + initial_state: StateType, + reverse: bool = False, + deterministic: bool = False): + """Unrolls the RNN cell over the inputs. + + Arguments: + inputs: A batch of sequences. Shape: [batch_size, seq_len, + input_dim]. + lengths: The lengths of the input sequences. + initial_state: The initial state for the RNN cell. Shape: [batch_size, + hidden_size]. + reverse: Process the inputs in reverse order, and reverse the outputs. + This means that the outputs still correspond to the order of the inputs, + but their contexts come from the right, and not from the left. + deterministic: Disables recurrent dropout if set to True. + + Returns: + The encoded sequence of inputs, shaped [batch_size, seq_len, + hidden_size], as well as the final hidden states of the RNN cell. For an + LSTM cell the final states are a tuple (c, h), each shaped [ + batch_size, hidden_size]. + """ + if reverse: + inputs = flip_sequences(inputs, lengths) + + recurrent_dropout_mask = None + _, (cell_states, outputs) = self.unroll_cell(initial_state, inputs, + recurrent_dropout_mask, + deterministic) + final_state = jax.tree.map( + lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states) + + if reverse: + outputs = flip_sequences(outputs, lengths) + + return outputs, final_state + + +class GenericRNN(nn.Module): + """Generic RNN class. + + This provides generic RNN functionality to encode sequences with any RNN cell. + The class provides unidirectional and bidirectional layers, and these are + stacked when asking for multiple layers. + + This class be used to create a specific RNN class such as LSTM or GRU. + + Attributes: + cell_type: An RNN cell class to use, e.g., `flax.linen.LSTMCell`. + hidden_size: The size of each recurrent cell. + num_layers: The number of stacked recurrent layers. The output of the first + layer, with optional dropout applied, feeds into the next layer. + dropout_rate: Dropout rate to be applied between LSTM layers. Only applies + when num_layers > 1. 
+ recurrent_dropout_rate: Dropout rate to be applied on the hidden state at + each time step repeating the same dropout mask. + bidirectional: Process the sequence left-to-right and right-to-left and + concatenate the outputs from the two directions. + cell_kwargs: Optional keyword arguments to instantiate the cell with. + """ + cell_type: Type[nn.RNNCellBase] + hidden_size: int + num_layers: int = 1 + dropout_rate: float = 0. + recurrent_dropout_rate: float = 0. + bidirectional: bool = False + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + + @nn.compact + def __call__( + self, + inputs: Array, + lengths: Array, + initial_states: Optional[Sequence[StateType]] = None, + deterministic: bool = False + ) -> Tuple[Array, Sequence[StateType]]: + """Processes the input sequence using the recurrent cell. + + Args: + inputs: The input sequence [batch_size, sequence_length, ...] + lengths: The lengths of each sequence in the batch. [batch_size] + initial_states: The initial states for the cells. You must provide + `num_layers` initial states (when using bidirectional, `num_layers * + 2`). + These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, + all initial states will be initialized with zeros. + deterministic: Disables dropout between layers when set to True. + Returns: + The sequence of all outputs for the final layer, and a list of final + states for each cell and direction. Directions are alternated (first + forward, then backward, if bidirectional). For example for a bidirectional + cell this would be: layer 1 forward, layer 1 backward, layer 2 forward, + layer 2 backward, etc.. + For some cells like LSTMCell a state consists of an (c, h) tuple, while + for others cells it only contains a single vector (h,). + """ + batch_size = inputs.shape[0] + final_states = [] + num_directions = 2 if self.bidirectional else 1 + num_cells = self.num_layers * num_directions + + # Construct initial states. + if initial_states is None: # Initialize with zeros. + rng = jax.random.PRNGKey(0) + initial_states = [ + self.cell_type(self.hidden_size).initialize_carry( + rng, (batch_size, 1) + ) + for _ in range(num_cells) + ] + if len(initial_states) != num_cells: + raise ValueError( + f'Please provide {self.num_cells} (`num_layers`, *2 if bidirectional)' + 'initial states.' + ) + + # For each layer, apply the forward and optionally the backward RNN cell. + cell_idx = 0 + for _ in range(self.num_layers): + # Unroll an RNN cell (forward direction) for this layer. + outputs, final_state = GenericRNNSequenceEncoder( + cell_type=self.cell_type, + cell_kwargs=self.cell_kwargs, + hidden_size=self.hidden_size, + recurrent_dropout_rate=self.recurrent_dropout_rate, + name=f'{self.name}SequenceEncoder_{cell_idx}')( + inputs, + lengths, + initial_state=initial_states[cell_idx], + deterministic=deterministic) + final_states.append(final_state) + cell_idx += 1 + + # Unroll an RNN cell (backward direction) for this layer. 
+ if self.bidirectional: + backward_outputs, backward_final_state = GenericRNNSequenceEncoder( + cell_type=self.cell_type, + cell_kwargs=self.cell_kwargs, + hidden_size=self.hidden_size, + recurrent_dropout_rate=self.recurrent_dropout_rate, + name=f'{self.name}SequenceEncoder_{cell_idx}')( + inputs, + lengths, + initial_state=initial_states[cell_idx], + reverse=True, + deterministic=deterministic) + outputs = jnp.concatenate([outputs, backward_outputs], axis=-1) + final_states.append(backward_final_state) + cell_idx += 1 + + inputs = outputs + + return outputs, final_states + + +class LSTM(nn.Module): + """LSTM. + + Attributes: + hidden_size: The size of each recurrent cell. + num_layers: The number of stacked recurrent layers. The output of the first + layer, with optional dropout applied, feeds into the next layer. + dropout_rate: Dropout rate to be applied between LSTM layers. Only applies + when num_layers > 1. + recurrent_dropout_rate: Dropout rate to be applied on the hidden state at + each time step repeating the same dropout mask. + bidirectional: Process the sequence left-to-right and right-to-left and + concatenate the outputs from the two directions. + cell_type: The LSTM cell class to use. Default: + `flax.linen.OptimizedLSTMCell`. If you use hidden_size of >2048, consider + using `flax.linen.LSTMCell` instead, since the optimized LSTM cell works + best for hidden sizes up to 2048. + cell_kwargs: Optional keyword arguments to instantiate the cell with. + """ + hidden_size: int + num_layers: int = 1 + dropout_rate: float = 0. + recurrent_dropout_rate: float = 0. + bidirectional: bool = False + cell_type: Any = nn.OptimizedLSTMCell + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + + @nn.compact + def __call__( + self, + inputs: Array, + lengths: Array, + initial_states: Optional[Sequence[StateType]] = None, + deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + """Processes an input sequence with an LSTM cell. + + Example usage: + ``` + inputs = np.random.normal(size=(2, 3, 4)) + lengths = np.array([1, 3]) + outputs, final_states = LSTM(hidden_size=10).apply(rngs, inputs, lengths) + ``` + + Args: + inputs: The input sequence [batch_size, sequence_length, ...] + lengths: The lengths of each sequence in the batch. [batch_size] + initial_states: The initial states for the cells. You must provide + `num_layers` initial states (when using bidirectional, `num_layers * + 2`). These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, + all initial states will be initialized with zeros. + deterministic: Disables dropout between layers when set to True. + + Returns: + The sequence of all outputs for the final layer, and a list of final + states (h, c) for each cell and direction, ordered first by layer number + and then by direction (first forward, then backward, if bidirectional). + """ + return GenericRNN( + cell_type=self.cell_type, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + dropout_rate=self.dropout_rate, + recurrent_dropout_rate=self.recurrent_dropout_rate, + bidirectional=self.bidirectional, + cell_kwargs=self.cell_kwargs, + name='LSTM')( + inputs, + lengths, + initial_states=initial_states, + deterministic=deterministic) class BatchRNN(nn.Module): """Implements a single deepspeech encoder layer. 
@@ -437,10 +754,24 @@ def __call__(self, inputs, input_paddings, train): config.batch_norm_epsilon)(inputs, input_paddings, train) - output = CudnnLSTM( - features=config.encoder_dim // 2, + + # For regular LSTM + hidden_size = ( + config.encoder_dim // 2 if config.bidirectional else config.encoder_dim + ) + lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) + + output, _ = LSTM( + hidden_size=hidden_size, bidirectional=config.bidirectional, - num_layers=1)(inputs, input_paddings) + num_layers=1, + )(inputs, lengths) + + # output = CudnnLSTM( + # features=config.encoder_dim // 2, + # bidirectional=config.bidirectional, + # num_layers=1)(inputs, input_paddings) + return output From f35690de13d415ce00619ad66676bb3cd2bc16a7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 12 Mar 2025 20:03:07 +0000 Subject: [PATCH 22/68] debuging --- .../librispeech_jax/models.py | 38 ++++++++++++++----- .../paper_baselines/adamw/jax/submission.py | 2 +- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 5d6be2f5d..e1d76730d 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -9,6 +9,9 @@ """ from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +from absl import logging + +import numpy as np import functools import flax @@ -385,6 +388,23 @@ def __call__( seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32) if use_cuda: + inputs_shape = np.shape(inputs) + h_0_shape = np.shape(h_0) + c_0_shape = np.shape(c_0) + weights_shape = np.shape(weights) + seq_lengths_np = np.shape(seq_lengths) + + n = jax.devices() + logging.info(f"jax num devices {n}") + logging.info(f'inputs shape {inputs_shape}') + logging.info(f'h_0 shape {h_0_shape}') + logging.info(f'c_0 shape {c_0_shape}') + logging.info(f'seq_lengths shape {seq_lengths_np}') + logging.info(f'weights_shape {weights_shape}') + logging.info(f'input_size {input_size}') + logging.info(f'hidden_size {self.features}') + logging.info(f'num_layers {self.num_layers}') + y, h, c = rnn.lstm( x=inputs, h_0=h_0, c_0=c_0, weights=weights, seq_lengths=seq_lengths, input_size=input_size, @@ -761,16 +781,16 @@ def __call__(self, inputs, input_paddings, train): ) lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) - output, _ = LSTM( - hidden_size=hidden_size, - bidirectional=config.bidirectional, - num_layers=1, - )(inputs, lengths) - - # output = CudnnLSTM( - # features=config.encoder_dim // 2, + # output, _ = LSTM( + # hidden_size=hidden_size, # bidirectional=config.bidirectional, - # num_layers=1)(inputs, input_paddings) + # num_layers=1, + # )(inputs, lengths) + + output = CudnnLSTM( + features=config.encoder_dim // 2, + bidirectional=config.bidirectional, + num_layers=1)(inputs, input_paddings) return output diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 5baa046f5..d381d3dfd 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -223,7 +223,7 @@ def get_batch_size(workload_name): elif workload_name == 'librispeech_conformer': return 256 elif workload_name == 'librispeech_deepspeech': - return 256 + return 32 elif workload_name == 'ogbg': return 512 elif 
workload_name == 'wmt': From da5f85a7c878f0399c7b8a5d2fcfb9d729e567ea Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 11 Mar 2025 15:46:49 +0100 Subject: [PATCH 23/68] first LM commit --- algoperf/workloads/lm/__init__.py | 0 algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++++++ algoperf/workloads/lm/input_pipeline.py | 82 ++++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/__init__.py | 0 algoperf/workloads/lm/lm_pytorch/workload.py | 36 +++++++++ algoperf/workloads/lm/test_01.py | 22 ++++++ algoperf/workloads/lm/test_input_pipeline.py | 68 ++++++++++++++++ algoperf/workloads/lm/workload.py | 66 ++++++++++++++++ 8 files changed, 316 insertions(+) create mode 100644 algoperf/workloads/lm/__init__.py create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/input_pipeline.py create mode 100644 algoperf/workloads/lm/lm_pytorch/__init__.py create mode 100644 algoperf/workloads/lm/lm_pytorch/workload.py create mode 100644 algoperf/workloads/lm/test_01.py create mode 100644 algoperf/workloads/lm/test_input_pipeline.py create mode 100644 algoperf/workloads/lm/workload.py diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + +trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py new file mode 100644 index 000000000..7424dd6d5 --- /dev/null +++ b/algoperf/workloads/lm/input_pipeline.py @@ -0,0 +1,82 @@ +"""Input pipeline for a LM dataset.""" +import functools +import os + +from datasets import Dataset, load_from_disk +from typing import Dict, List, Optional, Union + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup + +RANK = pytorch_setup()[1] +# Avoid multithreading in all processes but the first (rank 0). 
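The dev notes above outline how the tokenized SlimPajama split is meant to feed PyTorch. A minimal, self-contained version of that loader is sketched below; the path, batch size, and worker count are placeholders, and it only assumes the standard `datasets`/`torch` APIs. A map-style HF `Dataset` exposes `__getitem__` and `__len__`, so `DataLoader` can index and collate it directly once the rows are returned as torch tensors.

```python
import torch
from datasets import load_from_disk
from torch.utils.data import DataLoader

# Placeholder path and hyperparameters; substitute the real preprocessed split,
# where each row is assumed to hold seq_len + 1 token ids.
trainset_path = "/path/to/lm/train"
batch_size = 8
num_workers = 4

train_set = load_from_disk(trainset_path)    # map-style HF Dataset
train_set = train_set.with_format("torch")   # rows come back as torch tensors

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True)

batch = next(iter(train_loader))
# Shift by one token to form next-token-prediction pairs.
inputs, targets = batch['input_ids'][:, :-1], batch['input_ids'][:, 1:]
```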
+AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None + + +def get_lm_dataset(data_rng, + split: str, + data_dir: str, + is_training: bool, + vocab_size: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + vocab_path: Optional[str] = None): + """Load HF dataset and return a TF dataset.""" + + dataset_path = os.path.join(data_dir, split) + dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + + is_training = split == "train" + shuffle = split in ['train', 'eval_train'] + + def tf_generator(): + """Generates data in a TensorFlow-friendly format.""" + for example in dataset: + yield { + "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + } + + # Create a TensorFlow dataset from the generator function + ds = tf.data.Dataset.from_generator( + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), + } + ) + + # Avoid creating too many threads when using PyTorch DDP. + if RANK != 0: + options = tf.data.Options() + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + + if shuffle: + print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") + ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) + + if is_training: + ds = ds.repeat() + + # Batch the dataset, ensuring the last batch is dropped if not full during training + ds = ds.batch(global_batch_size, drop_remainder=is_training) + ds = ds.prefetch(AUTOTUNE) + + # Limit the dataset to a fixed number of batches if `num_batches` is specified + if num_batches: + ds = ds.take(num_batches) + + # Shard the dataset across multiple GPUs/TPUs if necessary + ds = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) + + return ds \ No newline at end of file diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/lm/lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py new file mode 100644 index 000000000..904657b1d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -0,0 +1,36 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from typing import Any, Dict, Optional, Tuple + +from absl import logging +import jax +import tensorflow as tf +import torch +import torch.distributed as dist +from torch.nn import DataParallel as DP +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn(): + pass + + def model_fn(): + pass + + def _build_input_queue(): + pass + + def eval_step(): + pass diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py new file mode 100644 index 000000000..e33ddf3e7 --- /dev/null +++ b/algoperf/workloads/lm/test_01.py @@ -0,0 +1,22 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = 
"/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + +tf_seed = SEED + +# Load the dataset +ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, +) diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/test_input_pipeline.py new file mode 100644 index 000000000..47c11969f --- /dev/null +++ b/algoperf/workloads/lm/test_input_pipeline.py @@ -0,0 +1,68 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + + +def test_tf_dataset(): + """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" + + print(f"Loading dataset from: {DATASET_PATH}") + + tf_seed = SEED + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + + print("Testing TensorFlow Dataset Output...") + for batch in ds.take(2): # Take two batches to test + print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection + print("Targets:", batch["targets"].numpy()) + +def test_pytorch_dataloader(): + """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" + + # Use the same TensorFlow-compatible seed + tf_seed = tf.constant(SEED, dtype=tf.int64) + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, + global_batch_size=BATCH_SIZE, + ) + + def _input_queue_generator(): + """Generator that converts TF dataset batches to PyTorch tensors.""" + for batch in iter(ds): + batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors + yield batch + + dataloader = _input_queue_generator() + + print("\nTesting PyTorch DataLoader Output...") + for _ in range(2): # Take two batches + batch = next(dataloader) + print("Inputs:", batch["inputs"]) + print("Targets:", batch["targets"]) + +# Run tests +if __name__ == "__main__": + test_tf_dataset() + test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py new file mode 100644 index 000000000..d070cabec --- /dev/null +++ b/algoperf/workloads/lm/workload.py @@ -0,0 +1,66 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Optional, Tuple + +import jax +import numpy as np +import torch + +from algoperf import spec +from algoperf.workloads.lm import input_pipeline + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """A LM workload.""" + + _vocab_size: int = 32000 + + def __init__(self) -> None: + super().__init__() + self._tokenizer = None + + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + is_training = split == 'train' + ds, self._tokenizer = input_pipeline.get_lm_dataset( 
+ data_rng, + split, + data_dir, + is_training=is_training, + vocab_size=self._vocab_size, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + + for batch in iter(ds): + yield batch + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the loss function at (label_batch, logits_batch).""" + pass \ No newline at end of file From a12a36404ce907c8e50e67c8e4a5eb25baa9a2f3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 12 Mar 2025 15:49:04 +0100 Subject: [PATCH 24/68] lm data pipeline --- algoperf/workloads/lm/input_pipeline.py | 11 +-- algoperf/workloads/lm/test_01.py | 96 +++++++++++++++++++++---- datasets/dataset_setup.py | 96 +++++++++++++++++++++++++ datasets/lm_preprocess.py | 0 4 files changed, 185 insertions(+), 18 deletions(-) create mode 100644 datasets/lm_preprocess.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7424dd6d5..a14cebeda 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from datasets import Dataset, load_from_disk from typing import Dict, List, Optional, Union +import jax import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -17,7 +18,7 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None -def get_lm_dataset(data_rng, +def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, is_training: bool, @@ -37,11 +38,12 @@ def get_lm_dataset(data_rng, def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: + input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), } - + # Create a TensorFlow dataset from the generator function ds = tf.data.Dataset.from_generator( tf_generator, @@ -58,7 +60,6 @@ def tf_generator(): ds = ds.with_options(options) if shuffle: - print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) if is_training: diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py index e33ddf3e7..977fae11a 100644 --- a/algoperf/workloads/lm/test_01.py +++ b/algoperf/workloads/lm/test_01.py @@ -1,22 +1,92 @@ + import os +import numpy as np import tensorflow as tf import torch + from datasets import load_from_disk +from absl import app +from absl import flags +from absl import logging + +from algoperf.profiler import PassThroughProfiler +from algoperf import random_utils as prng +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +tf.config.set_visible_devices([], 'GPU') + +# 
Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +# (nico) +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +flags.DEFINE_enum( + 'framework', + None, + enum_values=['jax', 'pytorch'], + help='Whether to use Jax or Pytorch for the submission. Controls among ' + 'other things if the Jax or Numpy RNG library is used for RNG.') + +FLAGS = flags.FLAGS +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - -tf_seed = SEED - -# Load the dataset -ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, -) +RNG_SEED = 1996 # Fixed random seed for reproducibility + + +def main(_): + profiler = PassThroughProfiler() + if FLAGS.framework == 'pytorch': + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + + rng = prng.PRNGKey(RNG_SEED) + data_rng, _, _, _ = prng.split(rng, 4) + + print(f"data_rng = {data_rng}") + + # Load the dataset + ds = get_lm_dataset( + data_rng=data_rng, + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + # Check if `ds` acts as a generator + if hasattr(ds, '__iter__'): + print("Dataset is an iterable/generator.") + + # Fetch first batch + try: + first_batch = next(iter(ds)) + print(f"Successfully retrieved first batch.") + except Exception as e: + print(f"Error retrieving first batch: {e}") + return + + # Print structure of a batch + print(f"First batch keys: {first_batch.keys()}") + print(f"First batch shapes:") + for key, value in first_batch.items(): + print(f" - {key}: {value.shape} (dtype: {value.dtype})") + + # Validate batch dimensions + assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" + assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" + assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
+ + print(f"Dataset is correctly batched and structured.") + print(f"Test completed successfully.") + +if __name__ == '__main__': + flags.mark_flag_as_required('framework') + app.run(main) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index efe923dbe..14dd24545 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,13 +76,21 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer +from datasets import lm_preprocess +import datasets as hf_datasets +# from datasets import load_dataset, Dataset +from transformers import AutoTokenizer + +import math import functools +import itertools import os import shutil import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -126,6 +134,9 @@ flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -699,6 +710,86 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) +def download_finewebedu(data_dir, tmp_dir): + """Download FineWebEdu-10B.""" + + # data_dir = "/fast/najroldi/data" + + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") + data_dir = os.path.join(data_dir, 'finewebedu') + + _maybe_mkdir(tmp_dir) + _maybe_mkdir(data_dir) + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + # cache_dir=tmp_dir + ) + + ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + + seq_len = 2048 + max_seq_length = seq_len+1 + map_setup = dict(batched=True, batch_size=1024, num_proc=8) + + # Tokenize + tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] + return tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False + ) + + tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + tokenized_dataset = ds.map( + tokenize, + remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', + 'language_score', 'token_count', 'score', 'int_score'], + **map_setup + ) + tokenizer.model_max_length = seq_len + + # Concat in chunks of max_seq_len + def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Concatenate text and generate chunks of max_seq_length""" + concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + if total_length >= max_seq_length: + total_length = (total_length // max_seq_length) * max_seq_length + result = { + k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] + for k, t in concatenated_examples.items() + } + return result + + lm_dataset = tokenized_dataset.map( + concat_chunck, + **map_setup + ) + + n_tokens = len(lm_dataset) * max_seq_length + logging.info(f"Number of tokens in dataset: {n_tokens:_}") + + # Split dataset into training and 
validation sets + # TODO: avoid (single doc) contamination between train and val + VAL_TOKENS = 10_000_000 + val_samples = VAL_TOKENS // max_seq_length + 1 + val_dataset = lm_dataset.select(range(val_samples)) + train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + + # Save datasets + train_dataset.save_to_disk(os.path.join(data_dir, f"train")) + val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -781,6 +872,11 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + if not FLAGS.skip_download: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/lm_preprocess.py b/datasets/lm_preprocess.py new file mode 100644 index 000000000..e69de29bb From ca83ab8954a9e164dc538cb4749847812ee0e032 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 14 Mar 2025 11:31:08 +0100 Subject: [PATCH 25/68] testing --- algoperf/workloads/lm/{ => dev}/test_01.py | 0 .../lm/{ => dev}/test_input_pipeline.py | 0 algoperf/workloads/lm/input_pipeline.py | 37 +++++---- .../workloads/lm/lm_jax/__init__.py | 0 algoperf/workloads/lm/lm_jax/workload.py | 20 +++++ algoperf/workloads/lm/lm_pytorch/workload.py | 56 ++++++++++++- algoperf/workloads/lm/test.py | 37 +++++++++ algoperf/workloads/lm/workload.py | 80 ++++++++++++++----- datasets/dataset_setup.py | 25 ++++-- 9 files changed, 211 insertions(+), 44 deletions(-) rename algoperf/workloads/lm/{ => dev}/test_01.py (100%) rename algoperf/workloads/lm/{ => dev}/test_input_pipeline.py (100%) rename datasets/lm_preprocess.py => algoperf/workloads/lm/lm_jax/__init__.py (100%) create mode 100644 algoperf/workloads/lm/lm_jax/workload.py create mode 100644 algoperf/workloads/lm/test.py diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/dev/test_01.py similarity index 100% rename from algoperf/workloads/lm/test_01.py rename to algoperf/workloads/lm/dev/test_01.py diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py similarity index 100% rename from algoperf/workloads/lm/test_input_pipeline.py rename to algoperf/workloads/lm/dev/test_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index a14cebeda..f0024e4a6 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -15,6 +15,10 @@ RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). +# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# automatic optimization (AUTOTUNE), while other processes disable it (None). +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal +# number of elements to prefetch or parallelize for dataset operations, improving performance. 
AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -30,34 +34,36 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + dataset = load_from_disk(dataset_path) is_training = split == "train" shuffle = split in ['train', 'eval_train'] + dataset.set_format("tensorflow") # tf.int64 + def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: - input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } - # Create a TensorFlow dataset from the generator function + # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + } + ) # Avoid creating too many threads when using PyTorch DDP. - if RANK != 0: + # Limits TensorFlow's threading for non-primary processes (RANK != 0) + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) + options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread + ds = ds.with_options(options) # apply threading restrictions if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -66,6 +72,9 @@ def tf_generator(): ds = ds.repeat() # Batch the dataset, ensuring the last batch is dropped if not full during training + # i.e. it groups consecutive elements into fixed-size chunks. 
+ # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), + # improving efficiency and parallelism in training ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) diff --git a/datasets/lm_preprocess.py b/algoperf/workloads/lm/lm_jax/__init__.py similarity index 100% rename from datasets/lm_preprocess.py rename to algoperf/workloads/lm/lm_jax/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py new file mode 100644 index 000000000..4cdb42409 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -0,0 +1,20 @@ +"""LM workload implemented in Jax.""" + +import functools +from typing import Dict, Optional, Tuple + +from flax import jax_utils +import jax +import jax.numpy as jnp +import numpy as np + +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + + @property + def eval_batch_size(self) -> int: + return 131_072 diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 904657b1d..9ee21ccb6 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -29,8 +29,58 @@ def init_model_fn(): def model_fn(): pass - def _build_input_queue(): - pass - + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + per_device_batch_size = int(global_batch_size / N_GPUS) + + # Only create and iterate over tf input pipeline in one Python process to + # avoid creating too many threads. + if RANK == 0: + np_iter = super()._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + while True: + if RANK == 0: + batch = next(np_iter) + inputs = torch.as_tensor( + batch['inputs'], dtype=torch.float32, device=DEVICE) + targets = torch.as_tensor( + batch['targets'], dtype=torch.float32, device=DEVICE) + # Send batch to other devices when using DDP. + if USE_PYTORCH_DDP: + dist.broadcast(inputs, src=0) + inputs = inputs[0] # TODO: check + dist.broadcast(targets, src=0) + targets = targets[0] # TODO: check + else: + batch = {} + inputs = torch.empty((N_GPUS, per_device_batch_size, 39), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(inputs, src=0) + inputs = inputs[RANK] + targets = torch.empty((N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(targets, src=0) + targets = targets[RANK] + + batch = { + 'inputs': inputs, + 'targets': targets, + # 'weights': weights, + } + yield batch + + def eval_step(): pass diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/test.py new file mode 100644 index 000000000..7e693d0af --- /dev/null +++ b/algoperf/workloads/lm/test.py @@ -0,0 +1,37 @@ +""" +Test data pipaline in JAX and PyTorch. + +Instantiate a workload and loops over the input queue. 
+""" + +import jax +import numpy as np +import torch + +import algoperf.workloads.lm.lm_jax.workload as lm_jax +# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch + + +data_rng = jax.random.PRNGKey(0) +split = 'train' +data_dir = "/fast/najroldi/data/finewebedu" +global_batch_size = 8 +num_batches = 10 +repeat_final_dataset = False + +# ------------------------------------------------------------------------------ +# JAX +# ------------------------------------------------------------------------------ + +# 1 GPU +workload = lm_jax.LmWorkload() + +input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + +next(input_queue) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index d070cabec..63d2c707e 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -32,7 +32,7 @@ def _build_input_queue(self, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): is_training = split == 'train' - ds, self._tokenizer = input_pipeline.get_lm_dataset( + ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, @@ -41,26 +41,66 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) - + for batch in iter(ds): yield batch - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. 
- logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the loss function at (label_batch, logits_batch).""" + def _eval_model_on_split(): + pass + + def eval_period_time_sec(): + pass + + def has_reached_test_target(): + pass + + def has_reached_validation_target(): + pass + + def init_model_fn(): + pass + + def is_output_params(): + pass + + def loss_fn(): + pass + + def loss_type(): + pass + + def max_allowed_runtime_sec(): + pass + + def model_fn(): + pass + + def num_eval_train_examples(): + pass + + def num_test_examples(): + pass + + def num_train_examples(): + pass + + def num_validation_examples(): + pass + + def step_hint(): + pass + + def test_target_value(): + pass + + def train_mean(): + pass + + def train_stddev(): + pass + + def validation_target_value(): + pass + + def target_metric_name(): pass \ No newline at end of file diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 14dd24545..aab793832 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,10 +76,8 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer -from datasets import lm_preprocess import datasets as hf_datasets -# from datasets import load_dataset, Dataset from transformers import AutoTokenizer import math @@ -721,6 +719,9 @@ def download_finewebedu(data_dir, tmp_dir): _maybe_mkdir(tmp_dir) _maybe_mkdir(data_dir) + # Use local disk instead of NFS for temp storage + os.environ["TMPDIR"] = tmp_dir + ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', @@ -745,7 +746,6 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_special_tokens_mask=False, return_attention_mask=False ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization tokenized_dataset = ds.map( tokenize, @@ -754,8 +754,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: **map_setup ) tokenizer.model_max_length = seq_len + + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + from datasets import load_from_disk + tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len + # TODO: this might take to much memory + # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 + # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples + # NOTE: the current approach leads to data loss at batch boundaries, + # but concatenation *cannot* happen if batched=False, + # because concat_chunck relies on processing multiple examples at once. + # (2) loss happening because of nproc>1: potentially losing the last tokens in each process + # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -767,13 +780,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck, + concat_chunck,\ **map_setup ) - - n_tokens = len(lm_dataset) * max_seq_length + n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets From e3e78dc6443c5485af64bfe986951f72d9754f99 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:18:41 +0100 Subject: [PATCH 26/68] LM workload tested torch pipeline --- algoperf/data_utils.py | 2 +- .../lm/dev/test_build_input_queue_torch.py | 80 +++++++++++++++++++ .../workloads/lm/{test.py => dev/test_jax.py} | 19 ++++- algoperf/workloads/lm/input_pipeline.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 5 +- algoperf/workloads/lm/lm_pytorch/workload.py | 68 +++++++++------- algoperf/workloads/lm/workload.py | 7 +- submission_runner.py | 2 +- 8 files changed, 146 insertions(+), 40 deletions(-) create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py rename algoperf/workloads/lm/{test.py => dev/test_jax.py} (63%) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..068c21c03 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree.map(_prepare, batch) + return jax.tree_util.tree_map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py new file mode 100644 index 000000000..86b1ca6b7 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -0,0 +1,80 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + +n_gpus = max(N_GPUS, jax.local_device_count()) + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + # batch = next(input_queue) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # Start test. 
+ for _ in range(100): + + batch = next(input_queue) + assert type(batch) == dict + + assert 'inputs' in batch + assert 'targets' in batch + + assert type(batch['inputs']) == torch.Tensor + assert type(batch['targets']) == torch.Tensor + + assert batch['inputs'].dtype == dtype + assert batch['targets'].dtype == dtype + + assert batch['inputs'].shape == (local_batch_size, seq_len) + assert batch['targets'].shape == (local_batch_size, seq_len) + + sync_ddp() + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/dev/test_jax.py similarity index 63% rename from algoperf/workloads/lm/test.py rename to algoperf/workloads/lm/dev/test_jax.py index 7e693d0af..4ba3de631 100644 --- a/algoperf/workloads/lm/test.py +++ b/algoperf/workloads/lm/dev/test_jax.py @@ -15,6 +15,7 @@ data_rng = jax.random.PRNGKey(0) split = 'train' data_dir = "/fast/najroldi/data/finewebedu" +seq_len = 2048 global_batch_size = 8 num_batches = 10 repeat_final_dataset = False @@ -34,4 +35,20 @@ num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) -next(input_queue) +batch = next(input_queue) +assert type(batch) == dict + +assert 'inputs' in batch +assert 'targets' in batch + +assert type(batch['inputs']) == np.ndarray +assert type(batch['targets']) == np.ndarray + +assert batch['inputs'].dtype == np.int64 +assert batch['targets'].dtype == np.int64 + +assert batch['inputs'].shape == (1, global_batch_size, seq_len) +assert batch['targets'].shape == (1, global_batch_size, seq_len) + +print(f"JAX devices = {jax.devices()}") +print("1") diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index f0024e4a6..e74490a16 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -25,7 +25,6 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - is_training: bool, vocab_size: int, global_batch_size: int, num_batches: Optional[int] = None, @@ -39,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 + dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? 
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 4cdb42409..773f8c54c 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -14,7 +14,4 @@ class LmWorkload(BaseLmWorkload): - - @property - def eval_batch_size(self) -> int: - return 131_072 + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 9ee21ccb6..0ff7884c7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,7 +1,7 @@ """LM workload implemented in PyTorch.""" import contextlib -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Iterator, Optional, Tuple from absl import logging import jax @@ -22,12 +22,6 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" - - def init_model_fn(): - pass - - def model_fn(): - pass def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -35,8 +29,12 @@ def _build_input_queue(self, data_dir: str, global_batch_size: int, num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) + + seq_len = 2048 # TODO: define it somewehere else + DTYPE = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. @@ -48,36 +46,50 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) + weights = None + while True: + # Only iterate over tf input pipeline in one Python process to + # avoid creating too many threads. if RANK == 0: - batch = next(np_iter) - inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) - targets = torch.as_tensor( - batch['targets'], dtype=torch.float32, device=DEVICE) + batch = next(np_iter) # pylint: disable=stop-iteration-return + inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: - dist.broadcast(inputs, src=0) - inputs = inputs[0] # TODO: check - dist.broadcast(targets, src=0) - targets = targets[0] # TODO: check + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + # We don't broadcast the shard for RANK 0. + dist.broadcast(inputs[1:], src=0) + dist.broadcast(targets[1:], src=0) + + # RANK 0 extracts his shard. If not DDP, this just flattens. + inputs, targets = inputs[0], targets[0] + else: - batch = {} - inputs = torch.empty((N_GPUS, per_device_batch_size, 39), - dtype=torch.float32, - device=DEVICE) + # Receive batch from rank 0. + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + + # N_GPUS - 1 since we don't broadcast the shard for RANK 0. 
+ inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) dist.broadcast(inputs, src=0) - inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) dist.broadcast(targets, src=0) - targets = targets[RANK] - + # RANK - 1 since we don't broadcast the shard for RANK 0. + inputs, targets = inputs[RANK-1], targets[RANK-1] + + if weights is None: + weights = torch.ones(per_device_batch_size, device=DEVICE) batch = { 'inputs': inputs, 'targets': targets, - # 'weights': weights, + 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 63d2c707e..7b1313dd7 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -31,12 +31,10 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - is_training = split == 'train' ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, - is_training=is_training, vocab_size=self._vocab_size, global_batch_size=global_batch_size, num_batches=num_batches, @@ -103,4 +101,7 @@ def validation_target_value(): pass def target_metric_name(): - pass \ No newline at end of file + pass + + def eval_batch_size(): + pass diff --git a/submission_runner.py b/submission_runner.py index a2521e77b..6fac50d99 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ From e6194950fc524793906127f09b330a8329ad079f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:34:10 +0100 Subject: [PATCH 27/68] LM workload - fix torch tests --- .../lm/dev/test_build_input_queue_torch.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py index 86b1ca6b7..66205d091 100644 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -41,30 +41,33 @@ def test_dataloader_torch(): data_dir=data_dir, global_batch_size=global_batch_size) - # batch = next(input_queue) - print(f"RANK {RANK} of {N_GPUS}") sync_ddp() # Start test. 
for _ in range(100): - + batch = next(input_queue) - assert type(batch) == dict + assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch - assert type(batch['inputs']) == torch.Tensor - assert type(batch['targets']) == torch.Tensor + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype - assert batch['inputs'].dtype == dtype - assert batch['targets'].dtype == dtype + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) - assert batch['inputs'].shape == (local_batch_size, seq_len) - assert batch['targets'].shape == (local_batch_size, seq_len) - - sync_ddp() + assert torch.equal(inputs[:,1:], targets[:,:-1]) print(f"=== ALL TEST PASSED ===") From d8e9c56738de817e561e79cffee638ab7197eaed Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:36 +0100 Subject: [PATCH 28/68] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 --------- algoperf/workloads/lm/dev/test_01.py | 92 ------------------- .../lm/dev/test_build_input_queue_torch.py | 83 ----------------- .../workloads/lm/dev/test_input_pipeline.py | 68 -------------- algoperf/workloads/lm/dev/test_jax.py | 54 ----------- 5 files changed, 339 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_01.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py delete mode 100644 algoperf/workloads/lm/dev/test_input_pipeline.py delete mode 100644 algoperf/workloads/lm/dev/test_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. 
- -""" - diff --git a/algoperf/workloads/lm/dev/test_01.py b/algoperf/workloads/lm/dev/test_01.py deleted file mode 100644 index 977fae11a..000000000 --- a/algoperf/workloads/lm/dev/test_01.py +++ /dev/null @@ -1,92 +0,0 @@ - -import os -import numpy as np -import tensorflow as tf -import torch - -from datasets import load_from_disk - -from absl import app -from absl import flags -from absl import logging - -from algoperf.profiler import PassThroughProfiler -from algoperf import random_utils as prng -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - - -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' -# (nico) -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - -flags.DEFINE_enum( - 'framework', - None, - enum_values=['jax', 'pytorch'], - help='Whether to use Jax or Pytorch for the submission. Controls among ' - 'other things if the Jax or Numpy RNG library is used for RNG.') - -FLAGS = flags.FLAGS -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -RNG_SEED = 1996 # Fixed random seed for reproducibility - - -def main(_): - profiler = PassThroughProfiler() - if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - - rng = prng.PRNGKey(RNG_SEED) - data_rng, _, _, _ = prng.split(rng, 4) - - print(f"data_rng = {data_rng}") - - # Load the dataset - ds = get_lm_dataset( - data_rng=data_rng, - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - # Check if `ds` acts as a generator - if hasattr(ds, '__iter__'): - print("Dataset is an iterable/generator.") - - # Fetch first batch - try: - first_batch = next(iter(ds)) - print(f"Successfully retrieved first batch.") - except Exception as e: - print(f"Error retrieving first batch: {e}") - return - - # Print structure of a batch - print(f"First batch keys: {first_batch.keys()}") - print(f"First batch shapes:") - for key, value in first_batch.items(): - print(f" - {key}: {value.shape} (dtype: {value.dtype})") - - # Validate batch dimensions - assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" - assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" - assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
- - print(f"Dataset is correctly batched and structured.") - print(f"Test completed successfully.") - -if __name__ == '__main__': - flags.mark_flag_as_required('framework') - app.run(main) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py deleted file mode 100644 index 66205d091..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ /dev/null @@ -1,83 +0,0 @@ - -import jax -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload - -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - -n_gpus = max(N_GPUS, jax.local_device_count()) - -def sync_ddp(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def test_dataloader_torch(): - # Test config. - rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = torch.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - print(f"RANK {RANK} of {N_GPUS}") - sync_ddp() - - # Start test. - for _ in range(100): - - batch = next(input_queue) - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - test_dataloader_torch() - - -if __name__ == '__main__': - main() - diff --git a/algoperf/workloads/lm/dev/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py deleted file mode 100644 index 47c11969f..000000000 --- a/algoperf/workloads/lm/dev/test_input_pipeline.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import tensorflow as tf -import torch -from datasets import load_from_disk - -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - - -def test_tf_dataset(): - """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" - - print(f"Loading dataset from: {DATASET_PATH}") - - tf_seed = SEED - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - - print("Testing TensorFlow Dataset Output...") - for batch in ds.take(2): # Take two batches to test - print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection - print("Targets:", batch["targets"].numpy()) - -def 
test_pytorch_dataloader(): - """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" - - # Use the same TensorFlow-compatible seed - tf_seed = tf.constant(SEED, dtype=tf.int64) - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, - global_batch_size=BATCH_SIZE, - ) - - def _input_queue_generator(): - """Generator that converts TF dataset batches to PyTorch tensors.""" - for batch in iter(ds): - batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors - yield batch - - dataloader = _input_queue_generator() - - print("\nTesting PyTorch DataLoader Output...") - for _ in range(2): # Take two batches - batch = next(dataloader) - print("Inputs:", batch["inputs"]) - print("Targets:", batch["targets"]) - -# Run tests -if __name__ == "__main__": - test_tf_dataset() - test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/dev/test_jax.py b/algoperf/workloads/lm/dev/test_jax.py deleted file mode 100644 index 4ba3de631..000000000 --- a/algoperf/workloads/lm/dev/test_jax.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Test data pipaline in JAX and PyTorch. - -Instantiate a workload and loops over the input queue. -""" - -import jax -import numpy as np -import torch - -import algoperf.workloads.lm.lm_jax.workload as lm_jax -# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch - - -data_rng = jax.random.PRNGKey(0) -split = 'train' -data_dir = "/fast/najroldi/data/finewebedu" -seq_len = 2048 -global_batch_size = 8 -num_batches = 10 -repeat_final_dataset = False - -# ------------------------------------------------------------------------------ -# JAX -# ------------------------------------------------------------------------------ - -# 1 GPU -workload = lm_jax.LmWorkload() - -input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - -batch = next(input_queue) -assert type(batch) == dict - -assert 'inputs' in batch -assert 'targets' in batch - -assert type(batch['inputs']) == np.ndarray -assert type(batch['targets']) == np.ndarray - -assert batch['inputs'].dtype == np.int64 -assert batch['targets'].dtype == np.int64 - -assert batch['inputs'].shape == (1, global_batch_size, seq_len) -assert batch['targets'].shape == (1, global_batch_size, seq_len) - -print(f"JAX devices = {jax.devices()}") -print("1") From 6b4ff12356c5f41b01ce703801b556a11079d354 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:58 +0100 Subject: [PATCH 29/68] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++ .../lm/dev/test_build_input_queue_jax.py | 127 ++++++++++++++++++ .../lm/tests/test_build_input_queue_torch.py | 87 ++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_torch.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + 
+trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py new file mode 100644 index 000000000..08354be74 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py @@ -0,0 +1,127 @@ + +# TODO: redo with pmap!! + +import os +import jax +import tensorflow as tf +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_jax.workload import LmWorkload + +# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + + +N_GPUS = jax.local_device_count() + +print(f"jax.local_devices() = {jax.local_devices()}") +print(f"jax.local_device_count() = {jax.local_device_count()}") + +print(f"N_GPUS = {N_GPUS}") + +def check_batch(batch): + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + +def process_shard(batch): + inputs, targets = batch['inputs'], batch['targets'] + jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) + jax.debug.print("inputs {inputs}", inputs=inputs) + jax.debug.callback(check_batch, batch) + return inputs, targets + +# Apply process_batch across devices, sharding batch across devices +pmap_process = jax.pmap(process_shard, axis_name='batch') + + +def test_dataloader_jax(): + # Test config. 
+ rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = np.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + batch = next(input_queue) + + inputs, targets = batch['inputs'], batch['targets'] + print(f"Processing on GPU with inputs: {inputs.shape}") + + inputs, targets = pmap_process(batch) + print(f"Processing on GPU with inputs: {inputs.shape}") + print(f"Processing on GPU with inputs: {inputs}") + + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs[0]: {inputs[0]}") + # print(f"inputs[1]: {inputs[1]}") + + # for device_id in range(2): + # # Access the sharded data for each GPU + # print(inputs.shape) + # device_inputs = inputs[device_id] + # print(f" GPU {device_id} Inputs: {device_inputs.shape}") + + # @jax.pmap + # def process_batch(batch): + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + + # return inputs, targets + + # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) + # print(f"inputs: {inputs[0]}") + + + +def main(): + test_dataloader_jax() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py new file mode 100644 index 000000000..83a18ec15 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -0,0 +1,87 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # batch = next(input_queue) + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs: {inputs}") + + # Start test. 
+ for _ in range(100): + + batch = next(input_queue) + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + From 3c5c847eb1489fa11a65c98c0f3327bd3c23c088 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:45:41 +0100 Subject: [PATCH 30/68] Stop tracking .gitignore --- .gitignore | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7d35f0ccc..000000000 --- a/.gitignore +++ /dev/null @@ -1,28 +0,0 @@ -__pycache__/* -__pycache__ -*egg-info -*eggs -.vscode/ -env/ -venv/ -workdir/ -makefile -*.out -*.sh -*.swp -*/data/ -*events.out.tfevents* -algoperf/workloads/librispeech_conformer/data_dir -algoperf/workloads/librispeech_conformer/work_dir -*.flac -*.npy -*.csv -*.vocab -wandb/ -*.txt -scoring/plots/ - -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv - -algoperf/_version.py From 20d841b1932408bc905051dc2e188f3a43e0d749 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:47:55 +0100 Subject: [PATCH 31/68] Remove dev/ from repo, keep locally --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ------ .../lm/dev/test_build_input_queue_jax.py | 127 ------------------ 2 files changed, 169 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. - -""" - diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py deleted file mode 100644 index 08354be74..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py +++ /dev/null @@ -1,127 +0,0 @@ - -# TODO: redo with pmap!! 
- -import os -import jax -import tensorflow as tf -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_jax.workload import LmWorkload - -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - - -N_GPUS = jax.local_device_count() - -print(f"jax.local_devices() = {jax.local_devices()}") -print(f"jax.local_device_count() = {jax.local_device_count()}") - -print(f"N_GPUS = {N_GPUS}") - -def check_batch(batch): - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - -def process_shard(batch): - inputs, targets = batch['inputs'], batch['targets'] - jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) - jax.debug.print("inputs {inputs}", inputs=inputs) - jax.debug.callback(check_batch, batch) - return inputs, targets - -# Apply process_batch across devices, sharding batch across devices -pmap_process = jax.pmap(process_shard, axis_name='batch') - - -def test_dataloader_jax(): - # Test config. 
- rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = np.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - batch = next(input_queue) - - inputs, targets = batch['inputs'], batch['targets'] - print(f"Processing on GPU with inputs: {inputs.shape}") - - inputs, targets = pmap_process(batch) - print(f"Processing on GPU with inputs: {inputs.shape}") - print(f"Processing on GPU with inputs: {inputs}") - - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - # print(f"inputs[0]: {inputs[0]}") - # print(f"inputs[1]: {inputs[1]}") - - # for device_id in range(2): - # # Access the sharded data for each GPU - # print(inputs.shape) - # device_inputs = inputs[device_id] - # print(f" GPU {device_id} Inputs: {device_inputs.shape}") - - # @jax.pmap - # def process_batch(batch): - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - - # return inputs, targets - - # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) - # print(f"inputs: {inputs[0]}") - - - -def main(): - test_dataloader_jax() - - -if __name__ == '__main__': - main() - From f3ba0593d955c657b6da8a07eede425509dbc6b9 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:00:44 +0100 Subject: [PATCH 32/68] fix comments --- algoperf/workloads/lm/input_pipeline.py | 2 +- datasets/dataset_setup.py | 27 +++++++------------------ 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e74490a16..bae1f5e45 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -38,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? + dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? 
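Regarding the TODO above: with the Hugging Face `datasets` library, `set_format("tensorflow")` only changes the container type returned when indexing or iterating the dataset, so `example["input_ids"]` comes back as an int64 `tf.Tensor`, which is consistent with the `tf.int64` output signature declared below. A hedged sketch of the alternative, converting to NumPy instead of relying on the TF format (the `np_generator` name is hypothetical):

    dataset.set_format("numpy")  # yield np.ndarray token ids instead of tf.Tensors

    def np_generator():
      for example in dataset:
        ids = example["input_ids"]  # np.ndarray of int64 token ids
        yield {"inputs": ids[:-1], "targets": ids[1:]}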
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index aab793832..8299133c1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,8 +711,6 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - # data_dir = "/fast/najroldi/data" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') @@ -726,7 +724,7 @@ def download_finewebedu(data_dir, tmp_dir): 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - # cache_dir=tmp_dir + cache_dir=tmp_dir ) ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size @@ -756,19 +754,11 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - from datasets import load_from_disk - tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len - # TODO: this might take to much memory - # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 - # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples - # NOTE: the current approach leads to data loss at batch boundaries, - # but concatenation *cannot* happen if batched=False, - # because concat_chunck relies on processing multiple examples at once. - # (2) loss happening because of nproc>1: potentially losing the last tokens in each process - # TODO: this does not allow to later change the seq_len... 
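A rough illustration of why increasing `batch_size` in `map_setup` reduces token loss: each mapped batch keeps only a whole number of `max_seq_length` chunks, so up to `max_seq_length - 1` tokens are dropped per mapped batch (and per worker). A sketch with a hypothetical per-batch token count:

    max_seq_length = 2049                 # seq_len + 1, as above
    batch_tokens = 2_100_000              # hypothetical: tokens in one mapped batch of 1024 docs
    kept = (batch_tokens // max_seq_length) * max_seq_length
    dropped = batch_tokens - kept         # at most max_seq_length - 1 tokens lost at this boundary
    print(f"kept={kept:_}, dropped={dropped:_}")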
not a problem in AlgoPerf, but bad in plainLM + # TODO (nico): this might take to much memory + # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -780,15 +770,12 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck,\ - **map_setup - ) - n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 + lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) + n_tokens = len(lm_dataset) * max_seq_length logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets - # TODO: avoid (single doc) contamination between train and val + # TODO (nico): avoid (single doc) contamination, by splitting before concatenation VAL_TOKENS = 10_000_000 val_samples = VAL_TOKENS // max_seq_length + 1 val_dataset = lm_dataset.select(range(val_samples)) From 381451f04a34e4a78a5256f92e1e7c092e0eadeb Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:46:45 +0100 Subject: [PATCH 33/68] add class specifications --- algoperf/workloads/lm/lm_jax/workload.py | 36 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 26 ++- algoperf/workloads/lm/workload.py | 201 +++++++++++++------ datasets/dataset_setup.py | 6 +- 4 files changed, 199 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 773f8c54c..84377b4bc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,17 +1,47 @@ """LM workload implemented in Jax.""" import functools -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple +from absl import logging from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np +import optax from algoperf import param_utils +from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload - class LmWorkload(BaseLmWorkload): - pass + """LM JAX workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0ff7884c7..404dc2532 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ 
b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -23,6 +23,24 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -93,6 +111,10 @@ def _build_input_queue(self, } yield batch - - def eval_step(): + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" pass diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 7b1313dd7..e36d54625 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -5,6 +5,9 @@ import os from typing import Any, Dict, Optional, Tuple +from absl import flags +import torch.distributed as dist + import jax import numpy as np import torch @@ -12,17 +15,98 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +FLAGS = flags.FLAGS + USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): - """A LM workload.""" + """LM workload.""" _vocab_size: int = 32000 def __init__(self) -> None: super().__init__() - self._tokenizer = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result['test/ppl'] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + pass + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. 
steps the baseline can do in the allowed runtime budget.""" + pass + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -43,65 +127,58 @@ def _build_input_queue(self, for batch in iter(ds): yield batch - def _eval_model_on_split(): - pass - - def eval_period_time_sec(): - pass - - def has_reached_test_target(): - pass - - def has_reached_validation_target(): - pass - - def init_model_fn(): - pass - - def is_output_params(): - pass - - def loss_fn(): - pass - - def loss_type(): - pass - - def max_allowed_runtime_sec(): - pass - - def model_fn(): - pass - - def num_eval_train_examples(): - pass - - def num_test_examples(): - pass - - def num_train_examples(): - pass - - def num_validation_examples(): - pass - - def step_hint(): - pass - - def test_target_value(): - pass - - def train_mean(): - pass - - def train_stddev(): - pass - - def validation_target_value(): - pass - - def target_metric_name(): - pass + @abc.abstractmethod + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True) + + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {'loss': mean_loss} - def eval_batch_size(): + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). 
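A minimal sketch of a masked LM cross-entropy that returns this dictionary, with illustrative shapes (label_batch of shape (batch, seq_len), logits_batch of shape (batch, seq_len, vocab)); the helper name and the convention used for counting valid examples are assumptions, not part of this patch:

import jax
import jax.numpy as jnp
import optax

def example_lm_loss(label_batch, logits_batch, mask_batch=None, label_smoothing=0.0):
  vocab_size = logits_batch.shape[-1]
  # Optional label smoothing applied to one-hot targets.
  targets = optax.smooth_labels(jax.nn.one_hot(label_batch, vocab_size), label_smoothing)
  per_token = optax.softmax_cross_entropy(logits_batch, targets)  # (batch, seq_len)
  if mask_batch is None:
    mask_batch = jnp.ones_like(per_token)
  per_example = jnp.sum(per_token * mask_batch, axis=-1)  # (batch,)
  return {
      'summed': jnp.sum(per_example),
      # Assumed convention: an example counts as valid if any of its tokens is unmasked.
      'n_valid_examples': jnp.sum(jnp.any(mask_batch > 0, axis=-1)),
      'per_example': per_example,
  }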
+ """ pass + + + diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8299133c1..fb8701f4d 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,11 +711,11 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') - - _maybe_mkdir(tmp_dir) + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ + else os.path.expanduser("~/.cache/huggingface/datasets") _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir From f111d2e8baada7af619504a87974fa78f3e34d55 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:29:37 +0100 Subject: [PATCH 34/68] add workload LM info --- algoperf/workloads/workloads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..6b99a25a6 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -114,6 +114,7 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' }, @@ -150,6 +151,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'lm', 'ogbg', 'wmt' ] From 808d398ee2cf78e92cea29e2d0696eb6ce592929 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:32:48 +0100 Subject: [PATCH 35/68] restore data_utils.py tree map --- algoperf/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 068c21c03..37d1bd20f 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. 
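For reference, a standalone sketch of the pmap-style layout restored here, with made-up shapes; per the comment above, the global batch size must divide evenly by the local device count:

import jax
import numpy as np

def to_per_device(batch, local_device_count):
  # (global_batch_size, ...) -> (local_device_count, per_device_batch_size, ...)
  return jax.tree.map(
      lambda x: x.reshape((local_device_count, -1, *x.shape[1:])), batch)

batch = {'inputs': np.zeros((32, 28, 28, 1)), 'targets': np.zeros((32,))}
sharded = to_per_device(batch, local_device_count=8)
assert sharded['inputs'].shape == (8, 4, 28, 28, 1)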
return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_util.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, From 35f8f8942cb993628f1b20c3d29346e4d7b40e95 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 14:38:41 +0100 Subject: [PATCH 36/68] fixed NFS bug --- datasets/dataset_setup.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index fb8701f4d..a68da3ff5 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -708,26 +708,28 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir): +def download_finewebedu(data_dir, tmp_dir=None): """Download FineWebEdu-10B.""" data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ - else os.path.expanduser("~/.cache/huggingface/datasets") + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) - # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - cache_dir=tmp_dir + cache_dir=cache_dir ) - ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + # Shuffle so that multiproc has shards of similar size. + ds = ds.shuffle(seed=1996) seq_len = 2048 max_seq_length = seq_len+1 @@ -754,11 +756,8 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - + # Concat in chunks of max_seq_len - # TODO (nico): this might take to much memory - # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} From cbb6ee67c6eb4828b574987d45fde508e5f1db67 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 15:02:27 +0100 Subject: [PATCH 37/68] train/val split before concat --- datasets/dataset_setup.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index a68da3ff5..5e27211e8 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -756,8 +756,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - - # Concat in chunks of max_seq_len + + # Find how many entries to take from dataset to have VAL_TOKENS in validation set. + VAL_TOKENS = 10_000_000 + tokens_accumulated, num_examples_for_val = 0, 0 + for example in tokenized_dataset: + tokens_accumulated += len(example['input_ids']) + num_examples_for_val += 1 + if tokens_accumulated >= VAL_TOKENS: + break + # Split in train and valid. 
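Splitting on raw tokenized documents, before concatenation, keeps each document entirely on one side of the split. A toy illustration of the boundary computation with made-up lengths; the selection code that follows applies the same logic to the real dataset:

# With per-document token counts [4, 3, 5, 2] and a validation budget of 8
# tokens, the first three documents are taken for validation (4 + 3 + 5 >= 8)
# and everything from index 3 onwards goes to train.
doc_lengths = [4, 3, 5, 2]
budget = 8
tokens_accumulated, num_examples_for_val = 0, 0
for length in doc_lengths:
  tokens_accumulated += length
  num_examples_for_val += 1
  if tokens_accumulated >= budget:
    break
assert num_examples_for_val == 3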
+ val_dataset = tokenized_dataset.select(range(num_examples_for_val)) + train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + + # Concat in chunks of max_seq_len. + # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -769,18 +782,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) - n_tokens = len(lm_dataset) * max_seq_length - logging.info(f"Number of tokens in dataset: {n_tokens:_}") - - # Split dataset into training and validation sets - # TODO (nico): avoid (single doc) contamination, by splitting before concatenation - VAL_TOKENS = 10_000_000 - val_samples = VAL_TOKENS // max_seq_length + 1 - val_dataset = lm_dataset.select(range(val_samples)) - train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + # Concat text in validation and train sets. + val_dataset = val_dataset.map(concat_chunck, **map_setup) + train_dataset = train_dataset.map(concat_chunck, **map_setup) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 848b50c7251ed8330510a3eb6853d9acafb6c265 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:23:04 +0000 Subject: [PATCH 38/68] reformatting --- tests/modeldiffs/wmt/compare.py | 6 +++--- tests/modeldiffs/wmt_attention_temp/compare.py | 6 +++--- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +++--- tests/modeldiffs/wmt_post_ln/compare.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 656d07fb5..64401ef7f 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -75,9 +75,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index d30474f00..01dc2895c 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 4b0a0b218..77e71c826 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def 
sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 818eec672..909fcd672 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From fb62eae40f2a5cb00c27118c26423256ebe5cd8e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:26:05 +0000 Subject: [PATCH 39/68] reformatting --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index eace6054d..a5c59000b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -392,8 +392,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: From fe3f9f021c7f621fea099746951a85d45e550df1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:26:45 +0000 Subject: [PATCH 40/68] reformatting --- .../paper_baselines/adamw/jax/submission.py | 20 ++++++------- .../paper_baselines/nadamw/jax/submission.py | 10 +++---- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 ++++++++++++------- .../target_setting_algorithms/jax_nadamw.py | 10 +++---- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index d381d3dfd..5bc0644a8 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -55,15 +55,16 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return optimizer_state, opt_update_fn + def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -163,8 +164,7 @@ def update_params( static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index ade3e9f63..c451a18ac 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ 
-123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index ce2d1feef..b76589705 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, - params, - updates) + noised_params = jax.tree_util.tree_map( + lambda p, u: p + rho * u, params, updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, + lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 4f670a85b..a5c2732ac 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. 
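The reformatted lines below keep the same logic; for context, a small self-contained sketch of the padding mask they compute, with made-up sizes:

import jax.numpy as jnp

matrix_size, padding_start = 4, 2
matrix = jnp.ones((matrix_size, matrix_size))
identity = jnp.eye(matrix_size)

# Rows and columns at or beyond padding_start are zeroed in both the statistics
# matrix and the identity, so padded dimensions cannot influence the
# convergence check of the inverse p-th root iteration.
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(matrix.dtype)
matrix = matrix * ix[jnp.newaxis, :] * ix[:, jnp.newaxis]
identity = identity * ix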
- ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree.map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree.map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), errors - >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), + errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2438,7 +2442,9 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree.map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2447,7 +2453,9 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree.map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index fbb7a612c..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,9 +108,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From 004afbd541c710cf37ab92546119011ffc36ad28 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg 
Date: Tue, 18 Mar 2025 19:27:15 +0000 Subject: [PATCH 41/68] reformatting --- .../external_tuning/jax_nadamw_full_budget.py | 10 ++++------ .../external_tuning/jax_nadamw_target_setting.py | 10 ++++------ .../self_tuning/jax_nadamw_full_budget.py | 10 ++++------ .../self_tuning/jax_nadamw_target_setting.py | 10 ++++------ 4 files changed, 16 insertions(+), 24 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index ade3e9f63..c451a18ac 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index d3a27411c..b8ac10f33 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 9bc014ed1..78c3b5b3e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else 
_bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index a6781612c..ffe854a0e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From f1db3d3e6a7f699ad4e3dcb04b11320f55d10a5c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 18 Mar 2025 19:27:51 +0000 Subject: [PATCH 42/68] reformatting --- algoperf/profiler.py | 4 +- .../criteo1tb/criteo1tb_jax/workload.py | 21 ++++----- .../workloads/fastmri/fastmri_jax/workload.py | 11 ++--- .../fastmri/fastmri_pytorch/workload.py | 4 +- .../imagenet_jax/custom_tf_addons.py | 46 +++++++++---------- .../imagenet_jax/randaugment.py | 8 ++-- .../imagenet_resnet/imagenet_jax/workload.py | 8 ++-- .../imagenet_pytorch/workload.py | 4 +- .../librispeech_jax/models.py | 10 ++-- .../librispeech_jax/spectrum_augmenter.py | 4 +- .../librispeech_jax/workload.py | 2 +- .../librispeech_pytorch/workload.py | 9 ++-- .../librispeech_jax/models.py | 41 +++++++++-------- algoperf/workloads/mnist/workload.py | 7 ++- algoperf/workloads/wmt/wmt_pytorch/models.py | 4 +- 15 files changed, 95 insertions(+), 88 deletions(-) diff --git a/algoperf/profiler.py b/algoperf/profiler.py index d73efd964..fa2a1bee2 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) - for a, d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) for a, + d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git 
a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 2f41eb8c6..30ecd2b00 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -130,16 +130,16 @@ def model_fn( return logits_batch, None @functools.partial( - jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), - ), - static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding() - ) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_batch_jitted(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( params, batch, @@ -160,8 +160,7 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return np.array( - self._eval_batch_jitted(params, batch), dtype=np.float64) + return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index cbe961a09..2709ef1e6 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -96,14 +96,11 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=( - sharding_utils.get_replicated_sharding(), - sharding_utils.get_naive_sharding_spec(), - sharding_utils.get_replicated_sharding() - ), + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding()), static_argnums=(0,), - out_shardings=sharding_utils.get_replicated_sharding() - ) + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 216a033d4..58943de2f 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -250,7 +250,9 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index c9bf154bb..3d6939218 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -20,31 +20,27 @@ tf.dtypes.float64, } -Number = Union[ - float, - int, - np.float16, - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, -] - -TensorLike = Union[ 
- List[Union[Number, list]], - tuple, - Number, - np.ndarray, - tf.Tensor, - tf.SparseTensor, - tf.Variable, -] +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] def get_ndims(image): diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 2d7e873c2..c68e2de33 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -316,7 +316,8 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), lambda: im, + tf.equal(step, 0), + lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -551,6 +552,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, selected_args=args: selected_func( - image, *selected_args), lambda: image) + lambda selected_func=func, + selected_args=args: selected_func(image, *selected_args), + lambda: image) return image diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 6a7cf03be..87b9b82bd 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -105,12 +105,12 @@ def init_model_fn( self._param_types = param_utils.jax_param_types(self._param_shapes) mesh = sharding_utils.get_mesh() params = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), + lambda x: jax.device_put(x, + sharding_utils.get_replicated_sharding(mesh)), params) model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) - ), + lambda x: jax.device_put(x, + sharding_utils.get_replicated_sharding(mesh)), model_state) return params, model_state diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index bd07edee1..ed29271f3 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -307,7 +307,9 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 2bb527a36..593d463c3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = 
self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,10 +442,12 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index c16740629..2a6f73d4d 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights - < multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights < + multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 8e3f3d975..d1bf29ba4 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -446,7 +446,7 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: return 0.094114 - + @property def test_target_value(self) -> float: return 0.056629 diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 1dae389d7..5ed37957e 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -259,9 +259,8 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], - device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( - -1, 1) + fin_result.shape[1], device=result.device).view( + 1, -1) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -329,7 +328,9 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index e1d76730d..0646aac52 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -144,8 +144,8 @@ def setup(self): self.kernel = 
self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -278,10 +278,12 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) @@ -393,7 +395,7 @@ def __call__( c_0_shape = np.shape(c_0) weights_shape = np.shape(weights) seq_lengths_np = np.shape(seq_lengths) - + n = jax.devices() logging.info(f"jax num devices {n}") logging.info(f'inputs shape {inputs_shape}') @@ -440,6 +442,7 @@ def unpack_weights( self.bidirectional, ) + ### Swap in regular LSTM layer for debuggin @jax.vmap def flip_sequences(inputs: Array, lengths: Array) -> Array: @@ -474,6 +477,7 @@ def flip_sequences(inputs: Array, lengths: Array) -> Array: idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length return inputs[idxs] + class GenericRNNSequenceEncoder(nn.Module): """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`. @@ -503,8 +507,11 @@ def setup(self): in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), out_axes=1, split_rngs={'params': False}) - def unroll_cell(self, cell_state: StateType, inputs: Array, - recurrent_dropout_mask: Optional[Array], deterministic: bool): + def unroll_cell(self, + cell_state: StateType, + inputs: Array, + recurrent_dropout_mask: Optional[Array], + deterministic: bool): """Unrolls a recurrent cell over an input sequence. Args: @@ -554,7 +561,8 @@ def __call__(self, inputs = flip_sequences(inputs, lengths) recurrent_dropout_mask = None - _, (cell_states, outputs) = self.unroll_cell(initial_state, inputs, + _, (cell_states, outputs) = self.unroll_cell(initial_state, + inputs, recurrent_dropout_mask, deterministic) final_state = jax.tree.map( @@ -602,8 +610,7 @@ def __call__( inputs: Array, lengths: Array, initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False - ) -> Tuple[Array, Sequence[StateType]]: + deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: """Processes the input sequence using the recurrent cell. Args: @@ -635,15 +642,12 @@ def __call__( rng = jax.random.PRNGKey(0) initial_states = [ self.cell_type(self.hidden_size).initialize_carry( - rng, (batch_size, 1) - ) - for _ in range(num_cells) + rng, (batch_size, 1)) for _ in range(num_cells) ] if len(initial_states) != num_cells: raise ValueError( f'Please provide {self.num_cells} (`num_layers`, *2 if bidirectional)' - 'initial states.' - ) + 'initial states.') # For each layer, apply the forward and optionally the backward RNN cell. cell_idx = 0 @@ -756,6 +760,7 @@ def __call__( initial_states=initial_states, deterministic=deterministic) + class BatchRNN(nn.Module): """Implements a single deepspeech encoder layer. 
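The flip_sequences helper used by the bidirectional path above reverses only the valid prefix of each padded sequence. A toy run with made-up values:

import jax
import jax.numpy as jnp

@jax.vmap
def flip_sequences(inputs, lengths):
  # Roll the reversed index so that only the first `lengths` positions flip
  # and the padding stays at the end.
  max_length = inputs.shape[0]
  idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length
  return inputs[idxs]

x = jnp.array([[1, 2, 3, 0, 0],
               [4, 5, 0, 0, 0]])
lengths = jnp.array([3, 2])
# flip_sequences(x, lengths) -> [[3, 2, 1, 0, 0],
#                                [5, 4, 0, 0, 0]]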
""" @@ -775,10 +780,9 @@ def __call__(self, inputs, input_paddings, train): input_paddings, train) - # For regular LSTM + # For regular LSTM hidden_size = ( - config.encoder_dim // 2 if config.bidirectional else config.encoder_dim - ) + config.encoder_dim // 2 if config.bidirectional else config.encoder_dim) lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) # output, _ = LSTM( @@ -791,7 +795,6 @@ def __call__(self, inputs, input_paddings, train): features=config.encoder_dim // 2, bidirectional=config.bidirectional, num_layers=1)(inputs, input_paddings) - return output diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index c92dd141b..f53aadd0b 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -46,7 +46,8 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'],}) + 'targets': x['label'], + }) is_train = split == 'train' if cache: @@ -213,6 +214,8 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index 089f1bfbb..a1c7ce15e 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) - >= cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. 
seq_len = k.size(1) From 868987c2fd72ced8107048e20de44a7e303074e8 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:41:05 +0100 Subject: [PATCH 43/68] renamed datasets to avoid conflict with HF --- {datasets => datasets_algoperf}/README.md | 0 .../dataset_setup.py | 17 ++++++++++------- .../librispeech_preprocess.py | 2 +- .../librispeech_tokenizer.py | 0 4 files changed, 11 insertions(+), 8 deletions(-) rename {datasets => datasets_algoperf}/README.md (100%) rename {datasets => datasets_algoperf}/dataset_setup.py (98%) rename {datasets => datasets_algoperf}/librispeech_preprocess.py (98%) rename {datasets => datasets_algoperf}/librispeech_tokenizer.py (100%) diff --git a/datasets/README.md b/datasets_algoperf/README.md similarity index 100% rename from datasets/README.md rename to datasets_algoperf/README.md diff --git a/datasets/dataset_setup.py b/datasets_algoperf/dataset_setup.py similarity index 98% rename from datasets/dataset_setup.py rename to datasets_algoperf/dataset_setup.py index 5e27211e8..21811e729 100644 --- a/datasets/dataset_setup.py +++ b/datasets_algoperf/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 datasets_algoperf/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -126,15 +126,15 @@ flags.DEFINE_boolean('fastmri', False, 'If --all=false, whether or not to download FastMRI.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('imagenet', False, 'If --all=false, whether or not to download Imagenet.') flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') -flags.DEFINE_boolean('finewebedu', - False, - 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -727,6 +727,8 @@ def download_finewebedu(data_dir, tmp_dir=None): split='train', cache_dir=cache_dir ) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading + # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) @@ -747,6 +749,7 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_attention_mask=False ) tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info(f"Tokenizing...") tokenized_dataset = ds.map( tokenize, remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', @@ -783,6 +786,7 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: } return result # Concat text in validation and train sets. 
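As noted earlier in this series, batched concatenation drops whatever does not fill a complete chunk. A self-contained sketch of that behaviour with toy values; the function and sizes here are illustrative, not the patch's own code:

import itertools

def concat_and_chunk(examples, max_seq_length=5):
  """Concatenate token lists and cut them into fixed-size chunks."""
  concatenated = {k: list(itertools.chain(*examples[k])) for k in examples}
  total = len(next(iter(concatenated.values())))
  # Leftover tokens that do not fill a full chunk are dropped; this is the
  # small, expected token loss at batch boundaries.
  total = (total // max_seq_length) * max_seq_length
  return {
      k: [t[i:i + max_seq_length] for i in range(0, total, max_seq_length)]
      for k, t in concatenated.items()
  }

batch = {'input_ids': [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12]]}
assert concat_and_chunk(batch)['input_ids'] == [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]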
+ logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") @@ -876,9 +880,8 @@ def main(_): download_wmt(data_dir) if FLAGS.all or FLAGS.finewebedu: - if not FLAGS.skip_download: - logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir) + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir, tmp_dir) # pylint: enable=logging-format-interpolation diff --git a/datasets/librispeech_preprocess.py b/datasets_algoperf/librispeech_preprocess.py similarity index 98% rename from datasets/librispeech_preprocess.py rename to datasets_algoperf/librispeech_preprocess.py index a8c5cae1d..cd291e5b3 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets_algoperf/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets import librispeech_tokenizer +from datasets_algoperf import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/datasets_algoperf/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to datasets_algoperf/librispeech_tokenizer.py From dd59dedc97f99e994221775b1e980d845bfb908c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:55:11 +0100 Subject: [PATCH 44/68] renamed datasets to dataset --- {datasets_algoperf => dataset}/README.md | 0 {datasets_algoperf => dataset}/dataset_setup.py | 6 +++--- {datasets_algoperf => dataset}/librispeech_preprocess.py | 2 +- {datasets_algoperf => dataset}/librispeech_tokenizer.py | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename {datasets_algoperf => dataset}/README.md (100%) rename {datasets_algoperf => dataset}/dataset_setup.py (99%) rename {datasets_algoperf => dataset}/librispeech_preprocess.py (98%) rename {datasets_algoperf => dataset}/librispeech_tokenizer.py (100%) diff --git a/datasets_algoperf/README.md b/dataset/README.md similarity index 100% rename from datasets_algoperf/README.md rename to dataset/README.md diff --git a/datasets_algoperf/dataset_setup.py b/dataset/dataset_setup.py similarity index 99% rename from datasets_algoperf/dataset_setup.py rename to dataset/dataset_setup.py index 21811e729..0c7f33de6 100644 --- a/datasets_algoperf/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets_algoperf/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -74,8 +74,8 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer import datasets as hf_datasets from transformers import AutoTokenizer diff --git a/datasets_algoperf/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 98% rename from datasets_algoperf/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index cd291e5b3..b96881332 100644 --- a/datasets_algoperf/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets_algoperf import librispeech_tokenizer +from 
dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets_algoperf/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets_algoperf/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py From c208cc7a760b5883c4127c79d457e0d352fd873b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 19 Mar 2025 19:59:51 +0000 Subject: [PATCH 45/68] sharding deepspeech --- .../librispeech_jax/models.py | 18 +++++++++--------- .../librispeech_jax/workload.py | 17 +++++++++++++++++ pyproject.toml | 8 ++------ .../paper_baselines/adamw/jax/submission.py | 2 +- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 0646aac52..1cab9e89e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -397,15 +397,15 @@ def __call__( seq_lengths_np = np.shape(seq_lengths) n = jax.devices() - logging.info(f"jax num devices {n}") - logging.info(f'inputs shape {inputs_shape}') - logging.info(f'h_0 shape {h_0_shape}') - logging.info(f'c_0 shape {c_0_shape}') - logging.info(f'seq_lengths shape {seq_lengths_np}') - logging.info(f'weights_shape {weights_shape}') - logging.info(f'input_size {input_size}') - logging.info(f'hidden_size {self.features}') - logging.info(f'num_layers {self.num_layers}') + # logging.info(f"jax num devices {n}") + # logging.info(f'inputs shape {inputs_shape}') + # logging.info(f'h_0 shape {h_0_shape}') + # logging.info(f'c_0 shape {c_0_shape}') + # logging.info(f'seq_lengths shape {seq_lengths_np}') + # logging.info(f'weights_shape {weights_shape}') + # logging.info(f'input_size {input_size}') + # logging.info(f'hidden_size {self.features}') + # logging.info(f'num_layers {self.num_layers}') y, h, c = rnn.lstm( x=inputs, h_0=h_0, c_0=c_0, weights=weights, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 392de7b89..c29e952cb 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -4,6 +4,8 @@ from flax import jax_utils import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P import numpy as np from algoperf import param_utils @@ -66,6 +68,21 @@ def model_fn( update_batch_norm: bool, use_running_average_bn: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + model_fn_sharded = shard_map(model_fn_ref, + self.mesh, + ) + + def model_fn_ref( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN diff --git a/pyproject.toml b/pyproject.toml index cc404f4b5..f4ebdaee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,15 +106,11 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - 
"jax==0.4.28", - "jaxlib==0.4.28", + "jax", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + "jax[cuda12]", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 5bc0644a8..8dc71950e 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -74,7 +74,7 @@ def _loss_fn(params): model_state, spec.ForwardPassMode.TRAIN, rng, - update_batch_norm=True) + update_batch_norm=True,) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, From 2e4cc9e9184a371f92af1989bb9a4ac8648d71ce Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 19 Mar 2025 23:15:15 +0000 Subject: [PATCH 46/68] ogbg jit migration --- algoperf/workloads/ogbg/input_pipeline.py | 9 +++- algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 24 ++++++++--- algoperf/workloads/ogbg/workload.py | 1 + .../paper_baselines/adamw/jax/submission.py | 43 +++++++++---------- submission_runner.py | 5 ++- 6 files changed, 53 insertions(+), 33 deletions(-) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..9506a5343 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -148,10 +148,15 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: + # yield { + # 'inputs': jraph.batch(graphs_shards), + # 'targets': np.vstack(labels_shards), + # 'weights': np.vstack(weights_shards) + # } def f(x): - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) - + return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:]) + graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) weights_shards = f(weights_shards) diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..9607a14e1 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -2,6 +2,7 @@ # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. 
from typing import Optional, Tuple +import jax from flax import linen as nn import jax.numpy as jnp import jraph @@ -78,7 +79,8 @@ def __call__(self, graph, train): self.hidden_dims, dropout=dropout, activation_fn=activation_fn), update_global_fn=_make_mlp( self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) - + # jax.debug.print(str(graph)) + graph = net(graph) # Map globals to represent the final result diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..347f89721 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,6 +8,7 @@ import jraph import optax +from algoperf import sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.ogbg import metrics @@ -45,7 +46,8 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(params), None + params = sharding_utils.shard_replicated(params) + return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' @@ -106,11 +108,20 @@ def _eval_metric(self, labels, logits, masks): return metrics.EvalMetrics.single_from_model_output( loss=loss['per_example'], logits=logits, labels=labels, mask=masks) + # @functools.partial( + # jax.pmap, + # axis_name='batch', + # in_axes=(None, 0, 0, 0, None), + # static_broadcasted_argnums=(0,)) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding(), + sharding_utils.get_replicated_sharding()), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding(), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) @@ -119,7 +130,8 @@ def _normalize_eval_metrics( Any]) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples - total_metrics = total_metrics.reduce() + # total_metrics = total_metrics.reduce() + print(total_metrics) return {k: float(v) for k, v in total_metrics.compute().items()} diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..14666c081 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -161,6 +161,7 @@ def _eval_batch(self, spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) + jax.debug.print(str(logits)) return self._eval_metric(batch['targets'], logits, batch['weights']) def _eval_model_on_split(self, diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 8dc71950e..6c6d19ef8 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,6 +75,7 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) + jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -140,31 +141,29 @@ def update_params( replicated = NamedSharding(mesh, P()) # No partitioning sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension - # Define input and 
output shardings - arg_shardings = ( - # workload is static - # opt_update_fn is static - replicated, # model_state - replicated, # optimizer_state - replicated, # current_param_container - sharded, # batch - replicated, # rng - replicated, # grad_clip - replicated # label_smoothing - ) - out_shardings = ( - replicated, # new_optimizer_state - replicated, # updated_params - replicated, # new_model_state - replicated, # loss - replicated # grad_norm - ) jitted_train_step = jax.jit( train_step, static_argnums=(0, 1), donate_argnums=(2, 3, 4), - in_shardings=arg_shardings, - out_shardings=out_shardings) + in_shardings= ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ), + out_shardings=( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + )) + # print(batch) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, @@ -176,7 +175,7 @@ def update_params( label_smoothing) # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: + if global_step % 1 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { 'loss': loss.item(), diff --git a/submission_runner.py b/submission_runner.py index a5c59000b..bc2c49b99 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -392,8 +392,9 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if False: + # if ((train_step_end_time - train_state['last_eval_time']) >= + # workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
if prepare_for_eval is not None: From 496b9c31f0bdd9a50e18a6907146969fd98e73cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:52:54 +0100 Subject: [PATCH 47/68] fix style --- .gitignore | 28 +++++++++++ algoperf/workloads/lm/input_pipeline.py | 50 ++++++++----------- algoperf/workloads/lm/lm_jax/workload.py | 15 +----- algoperf/workloads/lm/lm_pytorch/workload.py | 46 +++++++++-------- .../lm/tests/test_build_input_queue_torch.py | 18 +++---- algoperf/workloads/lm/workload.py | 12 ++--- 6 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..916a29ff4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +__pycache__/* +__pycache__ +*egg-info +*eggs +.vscode/ +env/ +venv/ +workdir/ +makefile +*.out +*.sh +*.swp +*/data/ +*events.out.tfevents* +algoperf/workloads/librispeech_conformer/data_dir +algoperf/workloads/librispeech_conformer/work_dir +*.flac +*.npy +*.csv +*.vocab +wandb/ +*.txt +scoring/plots/ + +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv + +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index bae1f5e45..53fe79276 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,24 +1,22 @@ """Input pipeline for a LM dataset.""" import functools import os +from typing import Optional -from datasets import Dataset, load_from_disk -from typing import Dict, List, Optional, Union - +from datasets import load_from_disk import jax -import numpy as np import tensorflow as tf -import tensorflow_datasets as tfds from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# This ensures that only the primary process (RANK == 0) uses TensorFlow's # automatic optimization (AUTOTUNE), while other processes disable it (None). -# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal -# number of elements to prefetch or parallelize for dataset operations, improving performance. +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine +# the optimal number of elements to prefetch or parallelize for dataset +# operations, improving performance. AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -44,25 +42,24 @@ def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + }) # Avoid creating too many threads when using PyTorch DDP. 
# Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread - ds = ds.with_options(options) # apply threading restrictions + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -70,10 +67,7 @@ def tf_generator(): if is_training: ds = ds.repeat() - # Batch the dataset, ensuring the last batch is dropped if not full during training - # i.e. it groups consecutive elements into fixed-size chunks. - # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), - # improving efficiency and parallelism in training + # Batch the dataset, grouping consecutive elements into fixed-size chunks. ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) @@ -83,9 +77,9 @@ def tf_generator(): # Shard the dataset across multiple GPUs/TPUs if necessary ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) - return ds \ No newline at end of file + return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 84377b4bc..64d538dda 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,22 +1,11 @@ """LM workload implemented in Jax.""" -import functools -from typing import Any, Dict, Iterator, Optional, Tuple +from typing import Dict, Optional, Tuple -from absl import logging -from flax import jax_utils -from flax import linen as nn -from flax.training import common_utils -import jax -import jax.numpy as jnp -import numpy as np -import optax - -from algoperf import param_utils -from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload + class LmWorkload(BaseLmWorkload): """LM JAX workload.""" diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 404dc2532..e57d26390 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -3,16 +3,10 @@ import contextlib from typing import Dict, Iterator, Optional, Tuple -from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist -from torch.nn import DataParallel as DP -import torch.nn.functional as F -from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload @@ -41,16 +35,17 @@ def model_fn( update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: pass - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - + 
seq_len = 2048 # TODO: define it somewehere else DTYPE = torch.int32 # TODO: decide between int32 and int64. @@ -65,20 +60,25 @@ def _build_input_queue(self, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + inputs = torch.as_tensor( + batch['inputs'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor( + batch['targets'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=DTYPE, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -95,12 +95,16 @@ def _build_input_queue(self, dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) - targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) + targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) # RANK - 1 since we don't broadcast the shard for RANK 0. 
- inputs, targets = inputs[RANK-1], targets[RANK-1] + inputs, targets = inputs[RANK - 1], targets[RANK - 1] if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 83a18ec15..639e71491 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -1,11 +1,6 @@ - import jax import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec + from algoperf.profiler import PassThroughProfiler from algoperf.pytorch_utils import pytorch_init from algoperf.pytorch_utils import pytorch_setup @@ -29,20 +24,20 @@ def test_dataloader_torch(): seq_len = 2048 local_batch_size = global_batch_size // N_GPUS - + workload = LmWorkload() data_rng = jax.random.PRNGKey(rng_seed) - + input_queue = workload._build_input_queue( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - + print(f"RANK {RANK} of {N_GPUS}") sync_ddp() - + # batch = next(input_queue) # inputs, targets = batch['inputs'], batch['targets'] # print(f"inputs.shape: {inputs.shape}") @@ -71,7 +66,7 @@ def test_dataloader_torch(): assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) - assert torch.equal(inputs[:,1:], targets[:,:-1]) + assert torch.equal(inputs[:, 1:], targets[:, :-1]) print(f"=== ALL TEST PASSED ===") @@ -84,4 +79,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e36d54625..3d04be3c5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -3,14 +3,11 @@ import abc import math import os -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Optional from absl import flags -import torch.distributed as dist - import jax -import numpy as np -import torch +import torch.distributed as dist from algoperf import spec from algoperf.workloads.lm import input_pipeline @@ -155,7 +152,7 @@ def _eval_model_on_split(self, global_batch_size, num_batches, repeat_final_dataset=True) - + for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) loss += self._eval_batch(params, eval_batch) @@ -179,6 +176,3 @@ def loss_fn( (not synced across devices). 
""" pass - - - From 50989eb6a8a54c43225a4243f770a4419d431a81 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:57:06 +0100 Subject: [PATCH 48/68] fix formatting --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 - submission_runner.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e57d26390..be6c94c46 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,6 +1,5 @@ """LM workload implemented in PyTorch.""" -import contextlib from typing import Dict, Iterator, Optional, Tuple import jax diff --git a/submission_runner.py b/submission_runner.py index d7df006bb..f8a66452d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: From 5af0fdc1437d924e2e162de5100e66782d01a7e5 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:02:22 +0100 Subject: [PATCH 49/68] fix style --- algoperf/workloads/lm/lm_pytorch/workload.py | 16 ++++++++-------- algoperf/workloads/lm/workload.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index be6c94c46..606f16ad7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -45,8 +45,8 @@ def _build_input_queue( not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - seq_len = 2048 # TODO: define it somewehere else - DTYPE = torch.int32 # TODO: decide between int32 and int64. + seq_len = self._seq_len # TODO: define it somewehere else? + dtype = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. @@ -66,10 +66,10 @@ def _build_input_queue( if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=DTYPE, + batch['inputs'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) targets = torch.as_tensor( - batch['targets'], dtype=DTYPE, + batch['targets'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. @@ -77,7 +77,7 @@ def _build_input_queue( if not_train: # During eval, the batch size of the remainder might be different. 
per_device_batch_size = torch.tensor( - len(targets[0]), dtype=DTYPE, device=DEVICE) + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -90,15 +90,15 @@ def _build_input_queue( # Receive batch from rank 0. if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 3d04be3c5..aa6d188b3 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,6 +21,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 32000 + _seq_len: int = 2048 def __init__(self) -> None: super().__init__() From 26830999b92d26c729171cae141ee7abb3409463 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:32:47 +0100 Subject: [PATCH 50/68] fix style --- algoperf/workloads/lm/workload.py | 2 +- dataset/dataset_setup.py | 91 +++++++++++++++++++------------ 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index aa6d188b3..4eb6c74a5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -24,7 +24,7 @@ class BaseLmWorkload(spec.Workload): _seq_len: int = 2048 def __init__(self) -> None: - super().__init__() + pass @property def target_metric_name(self) -> str: diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 0c7f33de6..8f0b09ab7 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -80,7 +80,6 @@ import datasets as hf_datasets from transformers import AutoTokenizer -import math import functools import itertools import os @@ -713,7 +712,9 @@ def download_finewebedu(data_dir, tmp_dir=None): data_dir = os.path.join(data_dir, 'finewebedu') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) @@ -722,75 +723,93 @@ def download_finewebedu(data_dir, tmp_dir=None): os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir - ) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. 
ds = ds.shuffle(seed=1996) seq_len = 2048 - max_seq_length = seq_len+1 + max_seq_length = seq_len + 1 map_setup = dict(batched=True, batch_size=1024, num_proc=8) # Tokenize - tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False - ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization logging.info(f"Tokenizing...") tokenized_dataset = ds.map( - tokenize, - remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', - 'language_score', 'token_count', 'score', 'int_score'], - **map_setup - ) - tokenizer.model_max_length = seq_len - + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ], + **map_setup) + lm_tokenizer.model_max_length = seq_len + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have VAL_TOKENS in validation set. - VAL_TOKENS = 10_000_000 + # Find how many entries to take from dataset to have val_tokens in validation set. + val_tokens = 10_000_000 # TODO: decide this value. tokens_accumulated, num_examples_for_val = 0, 0 for example in tokenized_dataset: tokens_accumulated += len(example['input_ids']) num_examples_for_val += 1 - if tokens_accumulated >= VAL_TOKENS: - break + if tokens_accumulated >= val_tokens: + break # Split in train and valid. val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + train_dataset = tokenized_dataset.select( + range(num_examples_for_val, len(tokenized_dataset))) # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + concatenated_examples = { + k: list(itertools.chain(*examples[k])) for k in examples.keys() + } total_length = len(concatenated_examples[list(examples.keys())[0]]) if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length + total_length = (total_length // max_seq_length) * max_seq_length result = { - k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] - for k, t in concatenated_examples.items() + k: [ + t[i:i + max_seq_length] + for i in range(0, total_length, max_seq_length) + ] for k, t in concatenated_examples.items() } return result + # Concat text in validation and train sets. 
logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") + logging.info( + f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info( + f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" + ) # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 6b7ee29684ee9bf1f9564032f65c09373212c4a4 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:36:27 +0100 Subject: [PATCH 51/68] fix yapf --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index f8a66452d..468a04c7c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: From 46b645b2ac4a4f4b93fe4ee6324b07f412fb81b3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:38:40 +0100 Subject: [PATCH 52/68] fix style --- dataset/dataset_setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 8f0b09ab7..6587f1439 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -797,7 +797,8 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: k: [ t[i:i + max_seq_length] for i in range(0, total_length, max_seq_length) - ] for k, t in concatenated_examples.items() + ] for k, + t in concatenated_examples.items() } return result From b3ae6474be93f07c578f885bae484773b8a65515 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 15:56:25 +0000 Subject: [PATCH 53/68] HF datasets pipeline --- algoperf/workloads/lm/input_pipeline.py | 75 ++++++++++- .../lm/tests/test_hf_input_pipeline.py | 116 ++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 algoperf/workloads/lm/tests/test_hf_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 53fe79276..ea4cb9d63 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -3,12 +3,17 @@ import os from typing import Optional -from datasets import load_from_disk import jax +import jax.numpy as jnp import tensorflow as tf +import torch +import torch.nn.functional as F +from transformers import GPT2Tokenizer from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup +from datasets import load_dataset +from datasets import load_from_disk RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). 
@@ -20,6 +25,74 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None +def get_hf_dataloader(cache_dir: str, + data_rng: jax.random.PRNGKey, + batch_size: int = 8, + seq_len: int = 32, + framework: str = "torch", + split="train"): + """ + Create a data loader from HuggingFace's FineWeb dataset. + + Args: + cache_dir: Directory to cache the dataset + batch_size: Number of sequences per batch + seq_len: Length of each sequence + framework: Either "torch" or "jax" to specify output tensor type + split: Dataset split to load + """ + # Initialize tokenizer and get vocab size + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + # Load the FineWeb dataset in streaming mode + fw = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-10BT", + split=split, + streaming=True, + cache_dir=cache_dir) + fw = fw.batch(batch_size=batch_size, drop_last_batch=True) + if split in ['train', 'eval_train']: + fw = fw.shuffle(seed=int(data_rng[-1])) + + def _tokenize(x): + """Tokenize and pad text to seq_len+1 tokens.""" + if framework == "torch": + tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) + elif framework == "jax": + tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = jnp.pad( + tokens, + pad_length, + mode="constant", + constant_values=tokenizer.pad_token_id) + return tokens[:seq_len + 1] + + def batch_iterator(): + for doc in fw: + if framework == "torch": + token_ids = torch.stack([_tokenize(x) for x in doc['text']]) + # Take first seq_len+1 tokens and convert to one-hot + tokens = F.one_hot(token_ids, num_classes=vocab_size).float() + # Split into input/target + inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] + inputs, targets = inputs.to("cuda"), targets.to("cuda") + elif framework == "jax": + token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) + tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) + inputs, targets = tokens[:, :-1], tokens[:, 1:] + devices = jax.devices("gpu") + inputs, targets = jax.device_put(inputs), jax.device_put(targets) + yield inputs, targets + + return batch_iterator() + + def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py new file mode 100644 index 000000000..36bab0d02 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py @@ -0,0 +1,116 @@ +"""Tests for LM HuggingFace input pipeline.""" +import os + +import jax +import jax.numpy as jnp +import torch +from transformers import GPT2Tokenizer + +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + +def main(): + # Setup test environment + cache_dir = "/home/ak4605/data" + if not os.path.exists(cache_dir): + raise FileNotFoundError(f"Cache directory {cache_dir} not found") + + data_rng = jax.random.PRNGKey(42) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + + print("Running JAX output shapes and types test...") + batch_size = 8 + seq_len = 32 + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, 
vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == jnp.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == jnp.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" + assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" + print("✓ JAX test passed") + + print("\nRunning Torch output shapes and types test...") + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="torch", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == torch.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == torch.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" + assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" + print("✓ Torch test passed") + + print("\nTesting consistent batching with same seed...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" + print("✓ Consistent batching test passed") + + print("\nTesting eval split doesn't shuffle...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(999)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" + print("✓ Eval no shuffling test passed") + + print("\nAll tests passed successfully!") + + +if __name__ == "__main__": + main() From f095d4b167dabc0e1aeb925b871f32f427fc22c8 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 17:03:05 +0000 Subject: [PATCH 54/68] Testing with linear model --- algoperf/workloads/lm/input_pipeline.py | 1 - algoperf/workloads/lm/lm_jax/models.py | 18 +++++++++ algoperf/workloads/lm/lm_jax/workload.py | 26 +++++++++++-- algoperf/workloads/lm/lm_pytorch/models.py | 18 +++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 32 +++++++++++++-- 
.../workloads/lm/tests/test_linear_model.py | 39 +++++++++++++++++++ algoperf/workloads/lm/workload.py | 17 ++------ 7 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/models.py create mode 100644 algoperf/workloads/lm/lm_pytorch/models.py create mode 100644 algoperf/workloads/lm/tests/test_linear_model.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index ea4cb9d63..cc658501e 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -86,7 +86,6 @@ def batch_iterator(): token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] - devices = jax.devices("gpu") inputs, targets = jax.device_put(inputs), jax.device_put(targets) yield inputs, targets diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py new file mode 100644 index 000000000..edfc102fa --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -0,0 +1,18 @@ +from flax import linen as nn +import jax.numpy as jnp + +class LinearModel(nn.Module): + vocab_size: int + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense( + 512, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(inputs) + return nn.Dense( + self.vocab_size, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 64d538dda..30b0c7867 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,8 +2,12 @@ from typing import Dict, Optional, Tuple +import jax.numpy as jnp +from flax import jax_utils +from algoperf import param_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_jax.models import LinearModel class LmWorkload(BaseLmWorkload): @@ -14,18 +18,32 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + model = LinearModel(vocab_size=self._vocab_size) + input_shape = (1, self._seq_len, self._vocab_size) + variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) + model_state, params = variables.pop('params') + + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + model_state = jax_utils.replicate(model_state) + params = jax_utils.replicate(params) + + return params, model_state def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del mode, rng, update_batch_norm # Not used for linear model + inputs = batch['inputs'] + logits = self._model.apply({'params': params, **model_state}, inputs) + return logits, model_state def _eval_batch(self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py new file mode 100644 index 000000000..545763924 --- /dev/null +++ 
b/algoperf/workloads/lm/lm_pytorch/models.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + +class LinearLayer(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.bottleneck = nn.Linear(vocab_size, 512) + self.output = nn.Linear(512, vocab_size) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.bottleneck.weight, std=0.02) + nn.init.zeros_(self.bottleneck.bias) + nn.init.normal_(self.output.weight, std=0.02) + nn.init.zeros_(self.output.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 606f16ad7..3395aa08f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -5,10 +5,13 @@ import jax import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_pytorch.models import LinearLayer USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -21,18 +24,39 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + if hasattr(self, '_model'): + self._model.reset_parameters() + return self._model, None + + torch.manual_seed(rng[0]) + self._model = LinearLayer(vocab_size=self._vocab_size) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del model_state, rng, update_batch_norm # Not used for linear model + model = params + inputs = batch['inputs'].float() # Convert one-hot to float + logits = model(inputs) + return logits, None def _build_input_queue( self, diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py new file mode 100644 index 000000000..31cd1d577 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_linear_model.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import torch + +TEST_SEQ_LEN = 512 + +def test_pytorch_linear(): + from algoperf.workloads.lm.lm_pytorch.models import LinearLayer + vocab_size = 32000 + model = LinearLayer(vocab_size) + + batch_size = 8 + seq_len = TEST_SEQ_LEN + inputs = torch.randn(batch_size, seq_len, vocab_size) + outputs = model(inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not torch.isnan(outputs).any() + +def test_jax_linear(): + from algoperf.workloads.lm.lm_jax.models import LinearModel + + vocab_size = 32000 + seq_len = TEST_SEQ_LEN + batch_size = 8 + model = LinearModel(vocab_size) + rng = jax.random.PRNGKey(0) + params = model.init(rng, 
jnp.ones((1, seq_len, vocab_size))) + + inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) + outputs = model.apply(params, inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not jnp.isnan(outputs).any() + +if __name__ == '__main__': + test_pytorch_linear() + test_jax_linear() + print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 4eb6c74a5..a06b17fdc 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -20,8 +20,8 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" - _vocab_size: int = 32000 - _seq_len: int = 2048 + _vocab_size: int = 50257 + _seq_len: int = 512 def __init__(self) -> None: pass @@ -106,6 +106,7 @@ def activation(self) -> str: def glu(self) -> bool: return True + @abc.abstractmethod def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -113,17 +114,7 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - ds = input_pipeline.get_lm_dataset( - data_rng, - split, - data_dir, - vocab_size=self._vocab_size, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - - for batch in iter(ds): - yield batch + """Build an input queue for the given split.""" @abc.abstractmethod def _eval_batch(self, From 0c22f3df420968cf820cbcc826f84a61751f95f5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:28:05 -0400 Subject: [PATCH 55/68] lm workload with linear model --- .../workloads/cifar/cifar_jax/workload.py | 11 -- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/lm_jax/models.py | 5 +- algoperf/workloads/lm/lm_jax/workload.py | 82 +++++++++-- algoperf/workloads/lm/lm_pytorch/workload.py | 129 ++++++++++-------- algoperf/workloads/lm/workload.py | 59 ++++---- pyproject.toml | 3 +- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 9 files changed, 187 insertions(+), 118 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..440de64c1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,7 +87,7 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets + yield {'inputs': inputs, 'targets': targets} return batch_iterator() diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..72ee5bd83 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -7,12 +7,13 @@ class LinearModel(nn.Module): @nn.compact def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: x = nn.Dense( - 512, + 10, kernel_init=nn.initializers.normal(0.02), bias_init=nn.initializers.zeros )(inputs) return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..7cb50302f 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,33 +2,57 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using HuggingFace FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="jax", + split=split) + return loader + def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - model = LinearModel(vocab_size=self._vocab_size) + self._model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + print(params_rng) + # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None 
return params, model_state def model_fn( @@ -40,15 +64,51 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model + del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) + else: + # Dense labels + loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..0d0281690 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,68 +66,38 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - - while True: - # Only iterate over tf input pipeline in one Python process to - # avoid creating too many threads. 
- if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor( - batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor( - batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - - # Send batch to other devices when using DDP. - if USE_PYTORCH_DDP: - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - # We don't broadcast the shard for RANK 0. - dist.broadcast(inputs[1:], src=0) - dist.broadcast(targets[1:], src=0) - - # RANK 0 extracts his shard. If not DDP, this just flattens. - inputs, targets = inputs[0], targets[0] - - else: - # Receive batch from rank 0. - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) + + dtype = torch.long + is_train = split == 'train' + + for batch in loader: + inputs, targets = batch + + if USE_PYTORCH_DDP: + if not is_train: + # During eval, the batch size of the remainder might be different + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - - # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) - targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) + + # Broadcast to all devices dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) - # RANK - 1 since we don't broadcast the shard for RANK 0. 
- inputs, targets = inputs[RANK - 1], targets[RANK - 1] + + if weights is None: + weights = torch.ones(inputs.shape[0], device=DEVICE) if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) @@ -138,10 +108,51 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in PyTorch.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) + loss = -torch.sum(label_batch * log_probs, dim=-1) + else: + # Dense labels + loss = torch.nn.functional.cross_entropy( + logits_batch, + label_batch, + reduction='none') + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..c10bf13e8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,6 +11,7 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS @@ -21,10 +22,13 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 512 + _seq_len: int = 5 + warmup_factor: float = 0.1 def __init__(self) -> None: - pass + super().__init__() + self._param_shapes = None + self._param_types = None @property def target_metric_name(self) -> str: @@ -36,14 +40,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - pass + return 20.0 # Target perplexity - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return eval_result['test/ppl'] <= self.test_target_value @property def test_target_value(self) -> float: - pass + return 20.0 # Target perplexity @property def loss_type(self) -> spec.LossType: @@ -51,23 +55,23 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - pass + return 1000000 # Example size @property def num_eval_train_examples(self) -> int: - pass + return 10000 # Subset for evaluation @property def num_validation_examples(self) -> int: - pass + return 50000 @property def num_test_examples(self) -> int: - pass + return 50000 @property def eval_batch_size(self) -> int: - pass + return 8 
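(Editor's aside, not part of the patch: the loss_fn contract introduced above returns a dict rather than a scalar. Below is a minimal sketch of how a submission typically consumes that dict, assuming dense integer targets and a padding mask; the shapes and values are made up for illustration and the mean is simply summed / n_valid_examples.)

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 4, 8, 50257
logits = torch.randn(batch, seq_len, vocab)          # model output
targets = torch.randint(0, vocab, (batch, seq_len))  # dense token IDs
mask = torch.ones(batch, seq_len)
mask[:, -2:] = 0                                      # pretend the last two positions are padding

per_token = F.cross_entropy(
    logits.reshape(-1, vocab), targets.reshape(-1), reduction='none'
).reshape(batch, seq_len) * mask
summed, n_valid = per_token.sum(), mask.sum()
mean_loss = summed / n_valid  # the quantity a submission would typically report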
@property def train_mean(self): @@ -79,16 +83,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - pass + return 3600 * 4 # 4 hours @property def eval_period_time_sec(self) -> int: - pass + return 600 # 10 minutes @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - pass + return 100000 @property def pre_ln(self) -> bool: @@ -116,13 +120,22 @@ def _build_input_queue(self, repeat_final_dataset: bool = False): """Build an input queue for the given split.""" - @abc.abstractmethod def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + + loss_dict = self.loss_fn(batch['targets'], logits) + return loss_dict['summed'] def _eval_model_on_split(self, split: str, @@ -145,9 +158,10 @@ def _eval_model_on_split(self, num_batches, repeat_final_dataset=True) + loss = 0.0 for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) + loss += self._eval_batch(params, eval_batch, model_state, rng) if USE_PYTORCH_DDP: dist.all_reduce(loss) mean_loss = loss.item() / num_examples @@ -155,16 +169,11 @@ def _eval_model_on_split(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. + @abc.abstractmethod def loss_fn( self, - label_batch: spec.Tensor, # Dense or one-hot labels. + label_batch: spec.Tensor, logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" diff --git a/pyproject.toml b/pyproject.toml index f4ebdaee3..745c6c680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algoperf/_version.py" [project.optional-dependencies] # All workloads full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", ] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] @@ -96,6 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +lm = ["transformers", "datasets"] # Frameworks jax_core_deps = [ diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From 99c7b9b70a374a25d6ac29c4f9a0f7c95e57c1aa Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:46:53 -0400 Subject: [PATCH 56/68] add nanodo model --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 345 ++++++++++++++++++ algoperf/workloads/lm/lm_jax/workload.py | 56 ++- .../paper_baselines/adamw/jax/submission.py | 4 +- 3 files changed, 386 insertions(+), 19 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/nanodo_model.py diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py new file mode 100644 index 000000000..d21fd5090 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -0,0 +1,345 @@ +# Self-contained version of the DecoderOnly Transformer from NanoDO + +import dataclasses +from functools import partial + +from flax import linen as nn +import jax +import jax.numpy as jnp + +# =========== Transformer Decoder-only Model ========== + + + +@dataclasses.dataclass +class DoConfig: + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + 
embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: DoConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD + +@partial(jax.jit, static_argnums=(0,1,2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack([ + jnp.cos(freqs)[None, :, None, :], + jnp.sin(freqs)[None, :, None, :] + ], axis=3) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack([ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] + ], axis=-1) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name="query") + self.multilinear_key = self.multilinear(name="key") + self.multilinear_value = self.multilinear(name="value") + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name="attn_out_proj", + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 + + # Compute attention scores + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = 
jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: DoConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name="output_proj" + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. 
+ + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + batch_size = y_BxL.shape[0] + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + print(f"\nModel Configuration:") + print(f" - Model dimension (D): {cfg.D}") + print(f" - Number of heads (H): {cfg.H}") + print(f" - Max sequence length (L): {cfg.L}") + print(f" - Number of layers (N): {cfg.N}") + print(f" - Vocabulary size (V): {cfg.V}") + print(f" - Feed forward dimension (F): {cfg.F}") + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print("\nInitializing model parameters...") + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Total parameters: {param_count:,}") + + # Make a prediction (forward pass) + print("\nRunning forward pass...") + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") + print(f"Output data type: {logits.dtype}") + + # Print sample logits (first 5 positions of the first sequence) + print("\nSample logits (first sequence, first 5 positions, first 5 values):") + for position in range(min(5, L)): + print(f" Position {position}: {logits[0, position, :5]}") + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + # Test the predict function + print("\nTesting predict function...") + # Use a shorter + short_seq = x_BxL[:, :10] + print(f"Input sequence shape: {short_seq.shape}") + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git 
a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 7cb50302f..9fdfe6f60 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -10,7 +10,8 @@ from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + TransformerDo, DoConfig, init_rope, apply_rope) from algoperf.workloads.lm.input_pipeline import get_hf_dataloader @@ -42,12 +43,22 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - self._model = LinearModel(vocab_size=self._vocab_size) - input_shape = (1, self._seq_len, self._vocab_size) + # Initialize NanoDO transformer model + cfg = DoConfig( + D=512, # model dim + H=8, # num heads + L=self._seq_len, + N=6, # num layers + V=self._vocab_size, + F=2048, # feedforward dim + dtype=jnp.float32 + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + params_rng, init_rng = jax.random.split(rng) - print(params_rng) - # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) - variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.int32)) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -66,6 +77,11 @@ def model_fn( del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] + + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) return logits, None @@ -76,23 +92,29 @@ def loss_fn( mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in JAX.""" - vocab_size = logits_batch.shape[-1] + # Convert one-hot labels to token IDs if needed + if len(label_batch.shape) == len(logits_batch.shape): # one-hot + label_batch = jnp.argmax(label_batch, axis=-1) - if len(label_batch.shape) == len(logits_batch.shape): - # One-hot labels - loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) - else: - # Dense labels - loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + # Reshape for sequence modeling + logits = logits_batch.reshape(-1, logits_batch.shape[-1]) + labels = label_batch.reshape(-1) + + # Compute cross-entropy loss + loss = -jnp.sum( + jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) if mask_batch is not None: - loss = loss * mask_batch + mask = mask_batch.reshape(-1) + loss = loss * mask + n_valid = mask.sum() + else: + n_valid = labels.shape[0] - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] return { - 'summed': loss.sum(), + 'summed': loss, 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss / n_valid # Return per-token loss } def is_output_params(self, param_name: str) -> bool: diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 6c6d19ef8..dca9a6b95 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py 
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,7 +75,6 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) - jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -163,7 +162,6 @@ def update_params( replicated, # loss replicated # grad_norm )) - # print(batch) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, @@ -229,6 +227,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 706d9f74046a0f1c90256ae584b45e30a38e4349 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 13:26:15 -0400 Subject: [PATCH 57/68] torch model --- algoperf/param_utils.py | 2 + .../workloads/lm/lm_pytorch/plainlm_model.py | 298 ++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 57 ++-- .../adamw/pytorch/submission.py | 2 + 4 files changed, 341 insertions(+), 18 deletions(-) create mode 100644 algoperf/workloads/lm/lm_pytorch/plainlm_model.py diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..24f981546 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -43,6 +43,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py new file mode 100644 index 000000000..627a0e16d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -0,0 +1,298 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from dataclasses import dataclass +from typing import Tuple + + + +@dataclass +class ModelConfig: + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = False + + +class MLP(nn.Module): + + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + condense_ratio: int = 1): + inv_freqs = 1.0 / (theta**(torch.arange( + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) + t = torch.arange(end, dtype=torch.float32, + device=inv_freqs.device) / condense_ratio + freqs = torch.outer(t, inv_freqs).float() + return torch.stack([ + torch.cos(freqs)[None, :, None, :], + torch.sin(freqs)[None, :, None, :] + ], + dim=4) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], 
dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - + qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads + + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + k = k.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + v = v.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + + out = out.transpose(1, 2).contiguous().view(bsz, seqlen, + d) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)]) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer('freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], + persistent=False) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) + self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies 
for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + # For debugging + predictions = [] + + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Print top 5 tokens for debugging + if i == 0: + print("\nPyTorch detailed prediction:") + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + predictions.append(next_token.item()) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + print(f" Full predictions step by step: {predictions}") + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith("fc2.weight"): # mlp/glu output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + if n.endswith("w_out.weight"): # attn output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print("Initializing transformer model and running forward pass...") + + seq_length = 512 + + # Define model configuration + config = ModelConfig( + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=768, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=12, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True # Tie embedding and output weights 
+ ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0d0281690..45ad0828f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -11,7 +11,7 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.models import LinearLayer +from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -26,11 +26,23 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: if hasattr(self, '_model'): - self._model.reset_parameters() + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() return self._model, None torch.manual_seed(rng[0]) - self._model = LinearLayer(vocab_size=self._vocab_size) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=512, # Model dimension + expand=4, # MLP expansion factor + n_layers=6, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True + ) + self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -46,15 +58,20 @@ def init_model_fn( def model_fn( self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm # Not used for linear model + del model_state, rng, update_batch_norm model = params - inputs = batch['inputs'].float() # Convert one-hot to float + + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + logits = model(inputs) return logits, None @@ -83,13 +100,14 @@ def _build_input_queue( is_train = split == 'train' for batch in loader: - inputs, targets = batch + inputs = batch['inputs'] + targets = batch['targets'] if USE_PYTORCH_DDP: if not is_train: # During eval, the batch size of the remainder might be different per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) + targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # Broadcast to all devices @@ -97,10 +115,8 @@ def _build_input_queue( dist.broadcast(targets, src=0) if weights is None: - weights = torch.ones(inputs.shape[0], device=DEVICE) - - if weights is None: - weights = torch.ones(per_device_batch_size, device=DEVICE) + batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() + weights = torch.ones((batch_size, seq_len), device=DEVICE) batch = { 'inputs': inputs, 'targets': 
targets, @@ -110,7 +126,7 @@ def _build_input_queue( def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" - return 'output.weight' in param_name or 'output.bias' in param_name + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name def _eval_batch(self, params: spec.ParameterContainer, @@ -121,11 +137,17 @@ def _eval_batch(self, model = params logits, _ = self.model_fn( model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = -torch.sum(targets * log_probs) + # Handle both one-hot and token ID targets + targets = batch['targets'] + if targets.dim() == 3: # one-hot + loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) + else: # token IDs + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction='sum' + ) return loss def loss_fn( self, @@ -146,7 +168,6 @@ def loss_fn( logits_batch, label_batch, reduction='none') - if mask_batch is not None: loss = loss * mask_batch diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..bdeaaf95b 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -173,6 +173,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From c335e341913dc6b1a747f2d3407e71a8d8e66ab6 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 29 May 2025 14:22:50 +0000 Subject: [PATCH 58/68] lm workload dataset integration in jax --- .../workloads/cifar/cifar_jax/workload.py | 11 - algoperf/workloads/lm/input_pipeline.py | 12 +- algoperf/workloads/lm/lm_jax/models.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 68 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 49 +-- algoperf/workloads/lm/workload.py | 313 +++++++++--------- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 8 files changed, 261 insertions(+), 209 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..8f68fcb55 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,19 +87,19 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets - + batch = { + "inputs": inputs, + "targets": targets, + } + yield batch return batch_iterator() def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - vocab_size: int, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - vocab_path: Optional[str] = None): + num_batches: Optional[int] = None): """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..7913f2c67 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -14,5 +14,6 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..6ad0e7d3d 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,16 +2,36 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using pre-cached FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + return loader def init_model_fn( self, @@ -21,14 +41,15 @@ def init_model_fn( model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None + self._model = model return params, 
model_state def model_fn( @@ -40,15 +61,40 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model - inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + del mode, rng, update_batch_norm, model_state + inputs = jax.nn.one_hot(batch['inputs'], self._vocab_size, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, # One-hot labels. + logits_batch: spec.Tensor, # Dense logits. + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: Optional[float] = 0.0) -> Dict[str, spec.Tensor]: + del mask_batch, label_smoothing + logits_flat = logits_batch.reshape(-1, self._vocab_size) + targets = jax.nn.one_hot(label_batch, self._vocab_size, axis=-1) + targets_flat = targets.reshape(-1, self._vocab_size) + # Cross-entropy loss + loss = -jnp.sum(targets_flat * jax.nn.log_softmax(logits_flat, axis=-1)) + n_valid_examples = logits_flat.shape[0] + return {'summed': loss, 'n_valid_examples': n_valid_examples} + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..2c6862160 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,35 +66,30 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return + batch = next(dataset_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) targets = torch.as_tensor( batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: @@ -138,10 +133,22 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..e6b33e3e4 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,160 +11,171 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ class BaseLmWorkload(spec.Workload): - """LM workload.""" - - _vocab_size: int = 50257 - _seq_len: int = 512 - - def __init__(self) -> None: - pass - - @property - def target_metric_name(self) -> str: - """The name of the target metric (useful for scoring/processing code).""" - return 'ppl' - - def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value - - @property - def validation_target_value(self) -> float: - pass - - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value - - @property - def test_target_value(self) -> float: - pass - - @property - def loss_type(self) -> spec.LossType: - return spec.LossType.SOFTMAX_CROSS_ENTROPY - - @property - def num_train_examples(self) -> int: - pass - - @property - def num_eval_train_examples(self) -> int: - pass - - @property - def num_validation_examples(self) -> int: - pass - - @property - def num_test_examples(self) -> int: - pass - - @property - def eval_batch_size(self) -> int: - pass - - @property - def train_mean(self): - raise NotImplementedError - - @property - def train_stddev(self): - raise NotImplementedError - - @property - def max_allowed_runtime_sec(self) -> int: - pass - - @property - def eval_period_time_sec(self) -> int: - pass - - @property - def step_hint(self) -> int: - """Approx. 
steps the baseline can do in the allowed runtime budget.""" - pass - - @property - def pre_ln(self) -> bool: - return True - - @property - def attention_temp(self) -> float: - return 1.0 - - @property - def activation(self) -> str: - return 'silu' - - @property - def glu(self) -> bool: - return True - - @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue for the given split.""" - - @abc.abstractmethod - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - num_batches = int(math.ceil(num_examples / global_batch_size)) - if split not in self._eval_iters: - # These iterators will repeat indefinitely. - self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) - - for _ in range(num_batches): - eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 512 + + def __init__(self) -> None: + pass + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return "ppl" + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result["validation/ppl"] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result["test/ppl"] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + return 8 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + # FIXME: should replace this with a real value later. + return 10000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return "silu" + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) + + loss = 0.0 + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch, model_state, rng) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {"loss": mean_loss} + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. 
+ logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ + pass diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From af8cce4d61e7f79916d7293127121ebaa4a4d7ce Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 03:20:46 +0000 Subject: [PATCH 59/68] set package versions for transformers and datasets --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 745c6c680..5e9c21f47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] -lm = ["transformers", "datasets"] +lm = ["transformers==4.25.4", "datasets==3.6.0"] # Frameworks jax_core_deps = [ From d68c54e0aa023570abc94cea97f5757bfb0baca8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 04:02:41 +0000 Subject: [PATCH 60/68] use train_test_split method to shuffle and split fineweb-edu dataset --- dataset/dataset_setup.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 6587f1439..7a83a03f6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -770,18 +770,10 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have val_tokens in validation set. - val_tokens = 10_000_000 # TODO: decide this value. - tokens_accumulated, num_examples_for_val = 0, 0 - for example in tokenized_dataset: - tokens_accumulated += len(example['input_ids']) - num_examples_for_val += 1 - if tokens_accumulated >= val_tokens: - break # Split in train and valid. 
- val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select( - range(num_examples_for_val, len(tokenized_dataset))) + dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset = dataset_split_dict['train'] + val_dataset = dataset_split_dict['test'] # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. From 9737367473f35b206333edc46f9c193ec8dda821 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:45:32 +0000 Subject: [PATCH 61/68] modifications to fwedu datasetup --- dataset/dataset_setup.py | 164 +++++++++++++++++---------------------- 1 file changed, 73 insertions(+), 91 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 7a83a03f6..584189c4a 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -191,6 +191,7 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') FLAGS = flags.FLAGS @@ -707,106 +708,87 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir=None): +def download_finewebedu(data_dir, + tmp_dir=None, + skip_download=False, + skip_tokenization=False): """Download FineWebEdu-10B.""" - data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, - 'lm') if tmp_dir is not None else os.path.expanduser( - '~/.cache/huggingface/datasets') - - _maybe_mkdir(data_dir) - _maybe_mkdir(tmp_dir) - _maybe_mkdir(cache_dir) - - os.environ["TMPDIR"] = tmp_dir - - ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading - # and allow re-chunking with different seq_len? - - # Shuffle so that multiproc has shards of similar size. 
- ds = ds.shuffle(seed=1996) - - seq_len = 2048 - max_seq_length = seq_len + 1 - map_setup = dict(batched=True, batch_size=1024, num_proc=8) - - # Tokenize - lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") - - def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq - add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return lm_tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False) - - lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization - logging.info(f"Tokenizing...") - tokenized_dataset = ds.map( - tokenize, - remove_columns=[ - 'text', - 'id', - 'dump', - 'url', - 'file_path', - 'language', - 'language_score', - 'token_count', - 'score', - 'int_score' - ], - **map_setup) - lm_tokenizer.model_max_length = seq_len - - tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + if not skip_download: + data_dir = os.path.join(data_dir, 'finewebedu') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ["TMPDIR"] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info("Tokenizing...") + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ],) + + tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + else: + tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] - # Concat in chunks of max_seq_len. - # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
- def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = { - k: list(itertools.chain(*examples[k])) for k in examples.keys() - } - total_length = len(concatenated_examples[list(examples.keys())[0]]) - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length - result = { - k: [ - t[i:i + max_seq_length] - for i in range(0, total_length, max_seq_length) - ] for k, - t in concatenated_examples.items() - } - return result - - # Concat text in validation and train sets. - logging.info(f"Concatenating and chunking...") - val_dataset = val_dataset.map(concat_chunck, **map_setup) - train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info( - f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info( - f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" - ) + # Convert to tensorflow_datasets.Dataset objects + train_dataset = train_dataset.to_tf_dataset() + val_dataset = train_dataset.to_tf_dataset() # Save datasets - train_dataset.save_to_disk(os.path.join(data_dir, f"train")) - val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + train_dataset.Save(os.path.join(data_dir, "train")) + val_dataset.save(os.path.join(data_dir, "val")) + + return def main(_): @@ -893,7 +875,7 @@ def main(_): if FLAGS.all or FLAGS.finewebedu: logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir, tmp_dir) + download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) # pylint: enable=logging-format-interpolation From 1bf0750e094a695176e8e3bc45ffd979abe9e237 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:46:26 +0000 Subject: [PATCH 62/68] rename fwedu data dir --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 584189c4a..ae27aab18 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -715,7 +715,7 @@ def download_finewebedu(data_dir, """Download FineWebEdu-10B.""" if not skip_download: - data_dir = os.path.join(data_dir, 'finewebedu') + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser( From a33339117b4c79d5fa946f4f7ed029087ab5a630 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 20:46:21 +0000 Subject: [PATCH 63/68] fix --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index ae27aab18..289a1faa6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -734,7 +734,7 @@ def download_finewebedu(data_dir, cache_dir=cache_dir) ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) else: - ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) if not skip_tokenization: # Tokenize From 05dc4dd7102670cebb8ac3a8875b34387d57b9b6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 21:22:57 +0000 Subject: [PATCH 64/68] add back batch mapping in tokenization for fwedu --- dataset/dataset_setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 289a1faa6..f50274615 100644 --- 
a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -769,7 +769,10 @@ def add_eos_batched(seqs): 'token_count', 'score', 'int_score' - ],) + ], + batched=True, + batch_size=1024, + num_proc=8) tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: From b374cf8db62e99e1594dea90b46a7f69a5bb04c6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:12:24 +0000 Subject: [PATCH 65/68] debugging --- dataset/dataset_setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index f50274615..2c46f4ebc 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -779,9 +779,11 @@ def add_eos_batched(seqs): tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. + print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] + print(type(train_dataset)) # Convert to tensorflow_datasets.Dataset objects train_dataset = train_dataset.to_tf_dataset() From c0c1e3c32c46d65cb7511891b32429aeeb05f90c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:13:48 +0000 Subject: [PATCH 66/68] debugging --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 2c46f4ebc..c18e72ea4 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -776,7 +776,7 @@ def add_eos_batched(seqs): tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: - tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. print(type(tokenized_dataset)) From f76dc392fa83a1da25194d401aa03a9dd6dc9c6a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:23:24 +0000 Subject: [PATCH 67/68] debugging --- dataset/dataset_setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index c18e72ea4..414b78609 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,6 +778,7 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset.to_tf_dataset() # Split in train and valid. print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) From e805fa7997daae83deea4e5336801af195270c1a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:45:07 +0000 Subject: [PATCH 68/68] use tfds to shuffle and split dataset --- dataset/dataset_setup.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 414b78609..747d06d27 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,20 +778,18 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) - tokenized_dataset.to_tf_dataset() - # Split in train and valid. 
-  print(type(tokenized_dataset))
-  dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
-  train_dataset = dataset_split_dict['train']
-  val_dataset = dataset_split_dict['test']
-  print(type(train_dataset))
 
-  # Convert to tensorflow_datasets.Dataset objects
-  train_dataset = train_dataset.to_tf_dataset()
-  val_dataset = train_dataset.to_tf_dataset()
+  tokenized_dataset = tokenized_dataset.to_tf_dataset()
 
-  # Save datasets
-  train_dataset.Save(os.path.join(data_dir, "train"))
+  # Shuffle once, then split 90/10 into train and validation sets.
+  # reshuffle_each_iteration=False fixes the order so take/skip stay disjoint.
+  dataset_size = tokenized_dataset.cardinality().numpy()
+  shuffled_dataset = tokenized_dataset.shuffle(
+      dataset_size, seed=0, reshuffle_each_iteration=False)
+  train_size = int(0.9 * dataset_size)
+  train_dataset = shuffled_dataset.take(train_size)
+  val_dataset = shuffled_dataset.skip(train_size)
+  train_dataset.save(os.path.join(data_dir, "train"))
  val_dataset.save(os.path.join(data_dir, "val"))
 
   return
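The splits written by download_finewebedu end up as tf.data snapshots under
<data_dir>/train and <data_dir>/val, so downstream LM input pipelines are
expected to read them back with tf.data.Dataset.load. Below is a minimal
sketch of such a reader; the helper name load_finewebedu_split, the batching
parameters, and the padding strategy are illustrative assumptions rather than
part of this patch series. Only the directory layout and the 'input_ids'
feature produced by the GPT-2 tokenizer come from the code above.

    import os

    import tensorflow as tf


    def load_finewebedu_split(data_dir: str, split: str, batch_size: int):
      """Hypothetical reader for the FineWeb-Edu splits saved above."""
      # The splits were written with tf.data.Dataset.save, so load() restores
      # the saved element structure (a dict holding the tokenized 'input_ids').
      ds = tf.data.Dataset.load(os.path.join(data_dir, split))
      # Sequences are variable-length (the final patch drops fixed-length
      # chunking), so pad each batch to the longest example it contains.
      ds = ds.padded_batch(batch_size, drop_remainder=(split == 'train'))
      return ds.prefetch(tf.data.AUTOTUNE)

    # Illustrative usage (paths and batch size are placeholders):
    # train_ds = load_finewebedu_split('/data/fineweb_edu_10B', 'train', 8)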