From de564a318f3337ba5991e4d57d00b800886f6e02 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 7 Nov 2025 00:12:29 -0800 Subject: [PATCH] Fix Python path for executing "hi" JAX code. PiperOrigin-RevId: 829303995 --- jax/_src/stages.py | 10 ++++++++-- tests/hijax_test.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 779744b4ca99..55475a0f34dd 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -435,8 +435,14 @@ def fall(self): _, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr) if closed_over_himutables: raise NotImplementedError # TODO(mattjj) lo_jaxpr = pe.lower_jaxpr(hi_jaxpr) - in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) - out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + if any(a.is_high for a in hi_jaxpr.final_aval_qdds): + in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) + else: + in_tree = self._in_tree + if any(a.is_high for a in hi_jaxpr.out_avals): + out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + else: + out_tree = self.out_tree params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr) lo_meta_tys = [mty.replace(aval=lo_ty) for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index b0609591dce7..4dcf49f74ce5 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -534,6 +534,7 @@ def f(x): if jit: f = jax.jit(f) + self.assertEqual(f.lower(2.0).compile()(2.0), 8.0) self.assertEqual(f(2.0), 8.0) xs = jnp.arange(3.0)