diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 779744b4ca99..55475a0f34dd 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -435,8 +435,14 @@ def fall(self): _, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr) if closed_over_himutables: raise NotImplementedError # TODO(mattjj) lo_jaxpr = pe.lower_jaxpr(hi_jaxpr) - in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) - out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + if any(a.is_high for a in hi_jaxpr.final_aval_qdds): + in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) + else: + in_tree = self._in_tree + if any(a.is_high for a in hi_jaxpr.out_avals): + out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + else: + out_tree = self.out_tree params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr) lo_meta_tys = [mty.replace(aval=lo_ty) for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index b0609591dce7..4dcf49f74ce5 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -534,6 +534,7 @@ def f(x): if jit: f = jax.jit(f) + self.assertEqual(f.lower(2.0).compile()(2.0), 8.0) self.assertEqual(f(2.0), 8.0) xs = jnp.arange(3.0)