Commit 89f929b

Move _jax_forward and _jax_backward inside j2t_autograd to avoid cache collisions (#9585)
1 parent c0eeb57 commit 89f929b
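
The NOTE in the diff below spells out the rationale: call_jax keys its compilation cache on the Python function object's id, so a module-level helper shared by every wrapped function maps them all to a single key, while a helper recreated inside each call never hits the cache at all. Nesting the helper in j2t_autograd gives an id that is unique per wrapped function yet stable across calls. The following is a minimal sketch of that behavior under those assumptions; it is not torchax's actual call_jax, and fake_call_jax and wrap are hypothetical names.

# A minimal sketch of the caching behavior the commit works around; this is NOT
# torchax's call_jax, and fake_call_jax / wrap are hypothetical names.
_cache = {}

def fake_call_jax(fn, *args):
  # call_jax is described as keying its compilation cache on the function's id.
  key = id(fn)
  if key not in _cache:
    print("tracing", key)  # stand-in for an expensive jax.jit-style trace
    _cache[key] = fn
  return _cache[key](*args)

def wrap(user_fn):
  # Nesting the helper here gives one function object per wrapped user_fn:
  # its id is stable across calls (so the cache is hit) and distinct from the
  # helpers of other wrapped functions (so their entries never collide).
  def _forward(x):
    return user_fn(x)
  return lambda x: fake_call_jax(_forward, x)

f = wrap(lambda x: x + 1)
g = wrap(lambda x: x * 2)
f(1); f(2)   # traced once; the second call reuses the cached entry
g(3)         # a different wrapped function gets its own entry, not f's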

File tree: 1 file changed, +30 -30 lines

torchax/torchax/interop.py

Lines changed: 30 additions & 30 deletions
@@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax):
   the PyTorch autograd framework by saving the residuals into the context object.
   """

+  # NOTE(qihqi): This function cannot be inlined from the callsite
+  # because if it is, then it won't hit the compilation cache for
+  # call_jax. call_jax uses a function's id as key.
+  # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+  # wrapped pure function, preventing cache collisions between different pure modules.
+  def _jax_forward(fn, other, tree_def, tensors):
+    """JAX function to compute output and vjp function.
+
+    primals should be a tuple (args, kwargs).
+    """
+    import jax
+    from jax.tree_util import tree_flatten, tree_unflatten
+
+    def fn_wrapper(*tensors):
+      # Reconstruct the original args and kwargs
+      flat_inputs = util.merge(tensors, other)
+      args, kwargs = tree_unflatten(tree_def, flat_inputs)
+      return fn(*args, **kwargs)
+
+    return jax.vjp(fn_wrapper, *tensors)
+
+  def _jax_backward(vjp_spec, saved_tensors, grad_out):
+    """JAX function to compute input gradients.
+
+    Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+    """
+    from jax.tree_util import tree_unflatten
+    fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+    return fun_vjp(grad_out)
+
   @wraps(fn)
   def inner(*args, **kwargs):
     from jax.tree_util import tree_flatten
@@ -290,36 +320,6 @@ def backward(ctx, *grad_out):
   return inner


-# NOTE(qihqi): This function cannot be inlined from the callsite
-# because if it is, then it won't hit the compilation cache for
-# call_jax. call_jax uses a function's id as key.
-def _jax_forward(fn, other, tree_def, tensors):
-  """JAX function to compute output and vjp function.
-
-  primals should be a tuple (args, kwargs).
-  """
-  import jax
-  from jax.tree_util import tree_flatten, tree_unflatten
-
-  def fn_wrapper(*tensors):
-    # Reconstruct the original args and kwargs
-    flat_inputs = util.merge(tensors, other)
-    args, kwargs = tree_unflatten(tree_def, flat_inputs)
-    return fn(*args, **kwargs)
-
-  return jax.vjp(fn_wrapper, *tensors)
-
-
-def _jax_backward(vjp_spec, saved_tensors, grad_out):
-  """JAX function to compute input gradients.
-
-  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
-  """
-  from jax.tree_util import tree_unflatten
-  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-  return fun_vjp(grad_out)
-
-
 fori_loop = torch_view(jax.lax.fori_loop)

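For context on what the moved helpers do, here is a small standalone JAX sketch (not torchax code) of the vjp round-trip that _jax_forward and _jax_backward perform, assuming, as the docstrings above state, that the pullback returned by jax.vjp flattens and unflattens as a pytree; the function f and its inputs are made up for illustration.

# A standalone JAX illustration (not torchax code) of the vjp round-trip that
# _jax_forward and _jax_backward perform; f and its inputs are made up.
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def f(x, y):
  return jnp.sin(x) * y

primals = (jnp.ones(3), jnp.full(3, 2.0))
out, f_vjp = jax.vjp(f, *primals)  # forward pass plus pullback closure

# The pullback flattens like any pytree: its leaves include the saved residuals,
# and the treedef plays the role of vjp_spec in _jax_backward.
saved_tensors, vjp_spec = tree_flatten(f_vjp)
restored_vjp = tree_unflatten(vjp_spec, saved_tensors)
grad_x, grad_y = restored_vjp(jnp.ones_like(out))  # cotangent in, input grads out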
