-
Notifications
You must be signed in to change notification settings - Fork 40
Fix is_jax_array for jax>=0.8.2
#369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
is_jax_array for jax>=0.8.2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR fixes a compatibility issue introduced in JAX version 0.8.2, where is_jax_array() stopped working correctly for JIT-compiled arrays. The fix adds a check for jax.core.Tracer subclasses, which is the type that JIT-compiled arrays become in JAX >= 0.8.2.
Key changes:
- Extended
is_jax_array()to detectjax.core.Tracerinstances in addition tojax.Array - Added a test to verify the fix works with JIT-compiled functions
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
array_api_compat/common/_helpers.py |
Added _issubclass_fast(cls, "jax.core", "Tracer") check to is_jax_array() function to handle JIT-compiled arrays in JAX >= 0.8.2 |
tests/test_common.py |
Added test_is_jax_array_jitted() to verify is_jax_array() works correctly inside and outside JIT-compiled functions |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| return ( | ||
| _issubclass_fast(cls, "jax", "Array") | ||
| or _issubclass_fast(cls, "jax.core", "Tracer") | ||
| or _is_jax_zero_gradient_array(x) | ||
| ) |
Copilot
AI
Dec 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check for jax.core.Tracer has been added to is_jax_array, but several other helper functions in this file may also need similar updates for consistency and correctness. Specifically:
_is_array_api_cls(line 302) - checks forjax.Arraybut notTracer_cls_to_namespace(line 550) - checks forjax.Arraybut notTracer_is_writeable_cls(line 940) - checks forjax.Arraybut notTracer(JAX tracers should also be non-writeable)_is_lazy_cls(line 979) - checks forjax.Arraybut notTracer(JAX tracers should also be lazy)
If is_jax_array now returns True for Tracers, these other functions should be updated to handle Tracers consistently. Otherwise, a jitted JAX array might pass is_jax_array but fail in array_namespace or behave incorrectly with is_writeable_array and is_lazy_array.
|
|
||
| x = jnp.asarray([1, 2, 3]) | ||
| assert is_jax_array(x) | ||
| assert jax.jit(lambda y: is_jax_array(y))(x) |
Copilot
AI
Dec 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
| assert jax.jit(lambda y: is_jax_array(y))(x) | |
| assert jax.jit(is_jax_array)(x) |
| return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) | ||
| return ( | ||
| _issubclass_fast(cls, "jax", "Array") | ||
| or _issubclass_fast(cls, "jax.core", "Tracer") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main reason for the change in v0.8.2 is that tracers now can represent more than just arrays, and so returning True for any tracer may lead to false positives.
The logic in Array.__instancecheck__ is what is required to accurately check in all contexts whether x is an array: https://github.com/jax-ml/jax/blob/82ae1b1cde42a5b93e00d8c3376cde627c2d83bb/jaxlib/py_array.cc#L2187-L2218
The easiest way to accomplish this would be to check isinstance(x, jax.Array) rather than recreating that logic here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That will force us to use a non-cachable operation, which is going to slow things down. But I don't think we have a choice given that the Tracer type itself no longer holds information on whether or not it's an Array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jakevdp Can you elaborate a bit more on which kinds of non-array objects now create tracers? I.e. we use an _is_writable_cls and _is_lazy_cls. Even if tracers are not arrays, these functions could still be decidable based on the type only. Are tracers still always lazy and always immutable? I realize that these questions might be ill-defined since tracers do not represent real objects and can disappear from the final computation graph, but for our purposes that's not an issue.
Also, could you show an example of a tracer that does not wrap an array? E.g. are bools in the input now traced as bools and not as arrays? This would be very helpful for testing.
@crusaderky Current helper methods such as _is_writable_cls are designed to return None for non-array API objects. It seems we cannot make that decision based off of type information only on jax>=0.8.2. Are you fine with relaxing the None strategy and returning True for Tracers in general, or do you want to be strict here? The former still fits into our current setup, the latter must use non-cachable isinstance checks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, could you show an example of a tracer that does not wrap an array?
An example is the new hijax Box type. There are no public APIs for this (yet), but here's how you can construct it using currently-private APIs at head:
import jax
from jax._src import hijax
box = hijax.new_box()
hijax.box_set(box, (jnp.arange(4), jnp.ones((3, 3)), 2.0, None))
@jax.jit
def f(box):
print(type(box)) # <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
print(box.aval) # BoxTy()
print(hijax.box_get(box)) # (JitTracer(int32[4]), JitTracer(float32[3,3]), JitTracer(~float32[]), None)
# print(box.dtype) # fails with AttributeError
# print(box.shape) # fails with AttributeError
f(box)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current design is that Tracer subclass reflects the type of transformation being traced (e.g. jit, vmap, grad, jaxpr, etc.) while the aval attribute can be inspected to see what kind of object is being traced.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that's very helpful. At this point I think we need a decision by the array-api-compat team. Both versions shouldn't be hard to implement.
@crusaderky @lucascolley what are your thoughts?
Jax released its version 0.8.2 https://github.com/jax-ml/jax/releases/tag/jax-v0.8.2, and it causes
is_jax_arrayto no longer work for jit compiled arrays. This PR fixes the problem by also testing if the passed object is ajax.core.Tracersubclass.Considerations
The current design requires one more check, but the check is compatible with caching and thus should be fast. I'm somewhat worried about correctness though. Are there tracers that are not jax Arrays? Maybe @jakevdp can weigh in on that question.
Related issue
#368