@@ -435,8 +435,14 @@ def fall(self):
435435 _ , closed_over_himutables = pe .convert_const_himutables (hi_jaxpr )
436436 if closed_over_himutables : raise NotImplementedError # TODO(mattjj)
437437 lo_jaxpr = pe .lower_jaxpr (hi_jaxpr )
438- in_tree = lojax_pytree (hi_jaxpr .in_aval_qdds , self ._in_tree )
439- out_tree = lojax_pytree (hi_jaxpr .out_avals , self .out_tree )
438+ if any (a .is_high for a in hi_jaxpr .final_aval_qdds ):
439+ in_tree = lojax_pytree (hi_jaxpr .in_aval_qdds , self ._in_tree )
440+ else :
441+ in_tree = self ._in_tree
442+ if any (a .is_high for a in hi_jaxpr .out_avals ):
443+ out_tree = lojax_pytree (hi_jaxpr .out_avals , self .out_tree )
444+ else :
445+ out_tree = self .out_tree
440446 params = dict (lojax_expand_params (hi_jaxpr , self ._params ), jaxpr = lo_jaxpr )
441447 lo_meta_tys = [mty .replace (aval = lo_ty )
442448 for mty , aq in zip (self ._meta_tys_flat , hi_jaxpr .in_aval_qdds )
@@ -791,7 +797,7 @@ def call(*args, **kwargs):
791797 "function was compiled by a transformation that does not support "
792798 f"keyword arguments, but called with keyword arguments: { kws } " )
793799
794- if params .is_high :
800+ if params .in_types and any ( a . is_high for a in params . in_types [ 1 ]) :
795801 hi_args_flat , in_hi_tree = tree_util .tree_flatten ((args , kwargs ))
796802 in_hi_tree_ , final_qdds = params .in_types
797803 args_flat = [a .read_loval (core .cur_qdd (x ), x ) if (a := typeof (x )).has_qdd
@@ -828,7 +834,7 @@ def call(*args, **kwargs):
828834 f"Tracer type { type (arg )} ." )
829835 lo_outs = params .executable .call (* params .const_args , * args_flat )
830836
831- if params .is_high :
837+ if params .out_types and any ( a . is_high for a in params . out_types [ 1 ]) :
832838 out_mut , lo_outs = util .split_list (lo_outs , [_num_himuts_out (final_qdds )])
833839 _apply_himut (final_qdds , hi_args_flat , out_mut )
834840 out_hi_tree , out_hi_types = params .out_types
0 commit comments