Skip to content

Conversation

@amacati
Copy link

@amacati amacati commented Dec 22, 2025

Jax released its version 0.8.2 https://github.com/jax-ml/jax/releases/tag/jax-v0.8.2, and it causes is_jax_array to no longer work for jit compiled arrays. This PR fixes the problem by also testing if the passed object is a jax.core.Tracer subclass.

Considerations

The current design requires one more check, but the check is compatible with caching and thus should be fast. I'm somewhat worried about correctness though. Are there tracers that are not jax Arrays? Maybe @jakevdp can weigh in on that question.

Related issue

#368

Copilot AI review requested due to automatic review settings December 22, 2025 21:00
@amacati amacati changed the title Fix is_jax_array for jax>=0.8.2 Fix is_jax_array for jax>=0.8.2 Dec 22, 2025
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a compatibility issue introduced in JAX version 0.8.2, where is_jax_array() stopped working correctly for JIT-compiled arrays. The fix adds a check for jax.core.Tracer subclasses, which is the type that JIT-compiled arrays become in JAX >= 0.8.2.

Key changes:

  • Extended is_jax_array() to detect jax.core.Tracer instances in addition to jax.Array
  • Added a test to verify the fix works with JIT-compiled functions

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
array_api_compat/common/_helpers.py Added _issubclass_fast(cls, "jax.core", "Tracer") check to is_jax_array() function to handle JIT-compiled arrays in JAX >= 0.8.2
tests/test_common.py Added test_is_jax_array_jitted() to verify is_jax_array() works correctly inside and outside JIT-compiled functions

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +238 to +242
return (
_issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer")
or _is_jax_zero_gradient_array(x)
)
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for jax.core.Tracer has been added to is_jax_array, but several other helper functions in this file may also need similar updates for consistency and correctness. Specifically:

  1. _is_array_api_cls (line 302) - checks for jax.Array but not Tracer
  2. _cls_to_namespace (line 550) - checks for jax.Array but not Tracer
  3. _is_writeable_cls (line 940) - checks for jax.Array but not Tracer (JAX tracers should also be non-writeable)
  4. _is_lazy_cls (line 979) - checks for jax.Array but not Tracer (JAX tracers should also be lazy)

If is_jax_array now returns True for Tracers, these other functions should be updated to handle Tracers consistently. Otherwise, a jitted JAX array might pass is_jax_array but fail in array_namespace or behave incorrectly with is_writeable_array and is_lazy_array.

Copilot uses AI. Check for mistakes.

x = jnp.asarray([1, 2, 3])
assert is_jax_array(x)
assert jax.jit(lambda y: is_jax_array(y))(x)
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change
assert jax.jit(lambda y: is_jax_array(y))(x)
assert jax.jit(is_jax_array)(x)

Copilot uses AI. Check for mistakes.
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
return (
_issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer")
Copy link
Contributor

@jakevdp jakevdp Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason for the change in v0.8.2 is that tracers now can represent more than just arrays, and so returning True for any tracer may lead to false positives.

The logic in Array.__instancecheck__ is what is required to accurately check in all contexts whether x is an array: https://github.com/jax-ml/jax/blob/82ae1b1cde42a5b93e00d8c3376cde627c2d83bb/jaxlib/py_array.cc#L2187-L2218

The easiest way to accomplish this would be to check isinstance(x, jax.Array) rather than recreating that logic here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will force us to use a non-cachable operation, which is going to slow things down. But I don't think we have a choice given that the Tracer type itself no longer holds information on whether or not it's an Array.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakevdp Can you elaborate a bit more on which kinds of non-array objects now create tracers? I.e. we use an _is_writable_cls and _is_lazy_cls. Even if tracers are not arrays, these functions could still be decidable based on the type only. Are tracers still always lazy and always immutable? I realize that these questions might be ill-defined since tracers do not represent real objects and can disappear from the final computation graph, but for our purposes that's not an issue.

Also, could you show an example of a tracer that does not wrap an array? E.g. are bools in the input now traced as bools and not as arrays? This would be very helpful for testing.

@crusaderky Current helper methods such as _is_writable_cls are designed to return None for non-array API objects. It seems we cannot make that decision based off of type information only on jax>=0.8.2. Are you fine with relaxing the None strategy and returning True for Tracers in general, or do you want to be strict here? The former still fits into our current setup, the latter must use non-cachable isinstance checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, could you show an example of a tracer that does not wrap an array?

An example is the new hijax Box type. There are no public APIs for this (yet), but here's how you can construct it using currently-private APIs at head:

import jax
from jax._src import hijax

box = hijax.new_box()
hijax.box_set(box, (jnp.arange(4), jnp.ones((3, 3)), 2.0, None))

@jax.jit
def f(box):
  print(type(box))  # <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
  print(box.aval)  # BoxTy()
  print(hijax.box_get(box))  # (JitTracer(int32[4]), JitTracer(float32[3,3]), JitTracer(~float32[]), None)
  # print(box.dtype)  # fails with AttributeError
  # print(box.shape)  # fails with AttributeError

f(box)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current design is that Tracer subclass reflects the type of transformation being traced (e.g. jit, vmap, grad, jaxpr, etc.) while the aval attribute can be inspected to see what kind of object is being traced.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that's very helpful. At this point I think we need a decision by the array-api-compat team. Both versions shouldn't be hard to implement.

@crusaderky @lucascolley what are your thoughts?

@lucascolley lucascolley self-requested a review December 23, 2025 23:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants