diff --git a/learned_optimization/tasks/es_wrapper.py b/learned_optimization/tasks/es_wrapper.py index 128efb1..044512d 100644 --- a/learned_optimization/tasks/es_wrapper.py +++ b/learned_optimization/tasks/es_wrapper.py @@ -105,15 +105,7 @@ def fn(theta, *args, es_key=None, **kwargs): if has_aux: losses, aux = aux_and_losses if not vec_aux: - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. - if jax.config.jax_pmap_shmap_merge: - aux = jax.tree_util.tree_map( - lambda x: x.addressable_shards[0].data.squeeze(0), - aux_and_losses[1], - ) - else: - aux = jax.tree_util.tree_map(lambda x: x[0], aux_and_losses[1]) + aux = jax.tree_util.tree_map(lambda x: x[0], aux_and_losses[1]) else: losses = aux_and_losses @@ -166,14 +158,7 @@ def new_vmap(key): keys = jax.random.split(es_key, n_pairs) if has_aux: (value, aux), grad = jax.vmap(new_vmap)(keys) - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. - if jax.config.jax_pmap_shmap_merge: - aux = jax.tree_util.tree_map( - lambda x: x.addressable_shards[0].data.squeeze(0), aux - ) - else: - aux = jax.tree_util.tree_map(lambda x: x[0], aux) + aux = jax.tree_util.tree_map(lambda x: x[0], aux) else: value, grad = jax.vmap(new_vmap)(keys)