Skip to content

Commit da822d2

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
Fix Python path for executing "hi" JAX code.
PiperOrigin-RevId: 829303995
1 parent 4cdb3ce commit da822d2

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

jax/_src/stages.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,14 @@ def fall(self):
435435
_, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr)
436436
if closed_over_himutables: raise NotImplementedError # TODO(mattjj)
437437
lo_jaxpr = pe.lower_jaxpr(hi_jaxpr)
438-
in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree)
439-
out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree)
438+
if any(a.is_high for a in hi_jaxpr.final_aval_qdds):
439+
in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree)
440+
else:
441+
in_tree = self._in_tree
442+
if any(a.is_high for a in hi_jaxpr.out_avals):
443+
out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree)
444+
else:
445+
out_tree = self.out_tree
440446
params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr)
441447
lo_meta_tys = [mty.replace(aval=lo_ty)
442448
for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds)
@@ -791,7 +797,7 @@ def call(*args, **kwargs):
791797
"function was compiled by a transformation that does not support "
792798
f"keyword arguments, but called with keyword arguments: {kws}")
793799

794-
if params.is_high:
800+
if params.in_types and any(a.is_high for a in params.in_types[1]):
795801
hi_args_flat, in_hi_tree = tree_util.tree_flatten((args, kwargs))
796802
in_hi_tree_, final_qdds = params.in_types
797803
args_flat = [a.read_loval(core.cur_qdd(x), x) if (a := typeof(x)).has_qdd
@@ -828,7 +834,7 @@ def call(*args, **kwargs):
828834
f"Tracer type {type(arg)}.")
829835
lo_outs = params.executable.call(*params.const_args, *args_flat)
830836

831-
if params.is_high:
837+
if params.out_types and any(a.is_high for a in params.out_types[1]):
832838
out_mut, lo_outs = util.split_list(lo_outs, [_num_himuts_out(final_qdds)])
833839
_apply_himut(final_qdds, hi_args_flat, out_mut)
834840
out_hi_tree, out_hi_types = params.out_types

tests/hijax_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def f(x):
534534

535535
if jit:
536536
f = jax.jit(f)
537+
self.assertEqual(f.lower(2.0).compile()(2.0), 8.0)
537538

538539
self.assertEqual(f(2.0), 8.0)
539540
xs = jnp.arange(3.0)

0 commit comments

Comments
 (0)