From 7b5c894d3b1956571e0f8c3cf3eae70e184c66fd Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 2 Feb 2026 13:03:05 -0800 Subject: [PATCH] [pmap] Remove conditional jax.config.jax_pmap_shmap_merge logic as part of last phase of jax.pmap clean-up. PiperOrigin-RevId: 864486731 --- learned_optimization/tasks/es_wrapper.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/learned_optimization/tasks/es_wrapper.py b/learned_optimization/tasks/es_wrapper.py index 128efb1..044512d 100644 --- a/learned_optimization/tasks/es_wrapper.py +++ b/learned_optimization/tasks/es_wrapper.py @@ -105,15 +105,7 @@ def fn(theta, *args, es_key=None, **kwargs): if has_aux: losses, aux = aux_and_losses if not vec_aux: - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. - if jax.config.jax_pmap_shmap_merge: - aux = jax.tree_util.tree_map( - lambda x: x.addressable_shards[0].data.squeeze(0), - aux_and_losses[1], - ) - else: - aux = jax.tree_util.tree_map(lambda x: x[0], aux_and_losses[1]) + aux = jax.tree_util.tree_map(lambda x: x[0], aux_and_losses[1]) else: losses = aux_and_losses @@ -166,14 +158,7 @@ def new_vmap(key): keys = jax.random.split(es_key, n_pairs) if has_aux: (value, aux), grad = jax.vmap(new_vmap)(keys) - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. - if jax.config.jax_pmap_shmap_merge: - aux = jax.tree_util.tree_map( - lambda x: x.addressable_shards[0].data.squeeze(0), aux - ) - else: - aux = jax.tree_util.tree_map(lambda x: x[0], aux) + aux = jax.tree_util.tree_map(lambda x: x[0], aux) else: value, grad = jax.vmap(new_vmap)(keys)