From 6a0415a8f638c756d0b4193feb42f4ef7e6efab9 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Mon, 10 Nov 2025 14:44:53 +0000 Subject: [PATCH] [remat] Make prevent_cse's tuple form account for consts --- jax/_src/ad_checkpoint.py | 3 ++- tests/api_test.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 658875e8c97c..2c4c6017580d 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -363,7 +363,8 @@ def fun_remat(*args, **kwargs): in_avals = [core.shaped_abstractify(x) for x in args_flat] jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals), debug) if isinstance(prevent_cse, tuple): - cse = (*broadcast_prefix(prevent_cse, (args, kwargs) if kwargs else args),) + cse_args = (tuple(args), kwargs) if kwargs else tuple(args) + cse = (False,) * len(consts) + tuple(broadcast_prefix(prevent_cse, cse_args)) else: cse = prevent_cse out_flat = remat_p.bind( diff --git a/tests/api_test.py b/tests/api_test.py index edd00d0bc5cb..855b8718f26a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7117,7 +7117,9 @@ def make_weight(i): def test_remat_partial_cse_prevention(self): @partial(jax.remat, prevent_cse=(False, True)) def layer(W, x): - return x @ W + res = x @ W + res += jnp.array([1.0, 2.0, 3.0]) # ensure the jaxpr also contains a const + return res def net(Ws, x): for W in Ws: