@@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax):
  the PyTorch autograd framework by saving the residuals into the context object.
  """

+  # NOTE(qihqi): This function cannot be inlined from the call site,
+  # because if it were, it would not hit the compilation cache for
+  # call_jax: call_jax uses the function's id as the cache key.
+  # It is nested inside j2t_autograd to ensure it gets a unique id for each
+  # wrapped pure function, preventing cache collisions between different pure modules.
+  def _jax_forward(fn, other, tree_def, tensors):
+    """JAX function to compute output and vjp function.
+
+    primals should be a tuple (args, kwargs).
+    """
+    import jax
+    from jax.tree_util import tree_flatten, tree_unflatten
+
+    def fn_wrapper(*tensors):
+      # Reconstruct the original args and kwargs
+      flat_inputs = util.merge(tensors, other)
+      args, kwargs = tree_unflatten(tree_def, flat_inputs)
+      return fn(*args, **kwargs)
+
+    return jax.vjp(fn_wrapper, *tensors)
+
+  def _jax_backward(vjp_spec, saved_tensors, grad_out):
+    """JAX function to compute input gradients.
+
+    Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+    """
+    from jax.tree_util import tree_unflatten
+    fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+    return fun_vjp(grad_out)
+
  @wraps(fn)
  def inner(*args, **kwargs):
    from jax.tree_util import tree_flatten
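The NOTE above turns on one detail: `call_jax` keys its compilation cache by the wrapper function's `id`. Below is a minimal standalone sketch of that behavior, assuming nothing more than an id-keyed dict; `fake_call_jax` and `make_wrapper` are hypothetical names used only for illustration and are not part of this diff.

_cache = {}

def fake_call_jax(fn, *args):
  # Hypothetical stand-in for call_jax: the cache key is the function object's id.
  key = id(fn)
  if key not in _cache:
    _cache[key] = f"compiled<{fn.__name__}>"  # stand-in for an expensive trace/compile
  return _cache[key]

def make_wrapper(pure_fn):
  # Defined once per wrapped function, like _jax_forward above: its id is stable
  # across calls of `inner`, and distinct from every other wrapped function's id.
  def _forward(*args):
    return pure_fn(*args)

  def inner(*args):
    # If _forward were created here instead, every call would build a new function
    # object with a new id, and the id-keyed cache would never hit.
    return fake_call_jax(_forward, *args)

  return inner

f = make_wrapper(lambda x: x * 2)
assert f(1) is f(2)  # both calls reuse the single cached entry for this wrapper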
@@ -290,36 +320,6 @@ def backward(ctx, *grad_out):
  return inner


-# NOTE(qihqi): This function cannot be inlined from the callsite
-# Becuase if it does, then it won't hit the compilation cache for
-# call_jax. Call jax uses functions' id as key.
-def _jax_forward(fn, other, tree_def, tensors):
-  """JAX function to compute output and vjp function.
-
-  primals should be a tuple (args, kwargs).
-  """
-  import jax
-  from jax.tree_util import tree_flatten, tree_unflatten
-
-  def fn_wrapper(*tensors):
-    # Reconstruct the original args and kwargs
-    flat_inputs = util.merge(tensors, other)
-    args, kwargs = tree_unflatten(tree_def, flat_inputs)
-    return fn(*args, **kwargs)
-
-  return jax.vjp(fn_wrapper, *tensors)
-
-
-def _jax_backward(vjp_spec, saved_tensors, grad_out):
-  """JAX function to compute input gradients.
-
-  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
-  """
-  from jax.tree_util import tree_unflatten
-  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-  return fun_vjp(grad_out)
-
-
fori_loop = torch_view(jax.lax.fori_loop)
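For reference, the `_jax_forward`/`_jax_backward` pair is essentially `jax.vjp` applied to a wrapper that rebuilds `(args, kwargs)` from flattened leaves. The sketch below shows that round trip in isolation; `pure_fn` and the shapes are made up, and it deliberately skips the `util.merge` split between tensor and non-tensor leaves that the real code performs.

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def pure_fn(x, w, bias=None):
  # Hypothetical pure JAX computation standing in for the wrapped function.
  out = x @ w
  return out + bias if bias is not None else out

# Flatten (args, kwargs) into leaves plus a tree definition, as `inner` does.
args = (jnp.ones((2, 3)), jnp.ones((3, 4)))
kwargs = {"bias": jnp.zeros((4,))}
flat_inputs, tree_def = tree_flatten((args, kwargs))

def fn_wrapper(*leaves):
  # Mirror of fn_wrapper in _jax_forward: rebuild (args, kwargs) and call the pure fn.
  a, kw = tree_unflatten(tree_def, leaves)
  return pure_fn(*a, **kw)

# Forward: primal output plus a vjp closure that carries the residuals.
y, fn_vjp = jax.vjp(fn_wrapper, *flat_inputs)

# Backward: feed the output cotangent to recover gradients for every input leaf,
# which is what _jax_backward does after restoring fn_vjp from the saved tensors.
grads = fn_vjp(jnp.ones_like(y))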