Describe the bug
Compiling what's provided in the reproduction leads to the following:
F external/xla/xla/permutation_util.h:47] Check failed: permutation.size() == data.size() (1 vs. 2)
Note that this only happens when using multiple cores (tested with 2); forcing NEURON_RT_VISIBLE_CORES=0 makes it work. It also works on CPU.
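For reference, this is how I apply the single-core workaround before launching the script (repro.py is a hypothetical name for the reproduction below):

```shell
# Workaround: pin the Neuron runtime to a single core for this process.
export NEURON_RT_VISIBLE_CORES=0
# Then run the reproduction script as usual, e.g.:
# python repro.py
```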
Model Name
N/A
Describe the workload type
N/A
Instance Type
inf2.8xlarge
Release version
aws-neuronx-collectives/unknown,now 2.31.24.0-1a31ba186 amd64
aws-neuronx-dkms/unknown,now 2.27.4.0 all
aws-neuronx-runtime-lib/unknown,now 2.31.24.0-0b044f4ce amd64
aws-neuronx-tools/unknown,now 2.29.18.0-d5fe7ba42 amd64
Using Python 3.13.13
jax 0.6.2
jax-neuronx 0.7.0.1.0.8181+1e892be0
jaxlib 0.6.2
libneuronxla 2.2.16408.0+50c26cbd
neuronx-cc 2.23.6484.0+3b612583
I have also tested without JAX and got the same crash on these versions:
aws-neuronx-runtime-lib 2.29.40.0-f954cd7a5
aws-neuronx-collectives 2.29.41.0-681fef5f5
libneuronxla 2.2.16408.0+50c26cbd
neuronx-cc 2.24.5133.0+58f8de22
Reproduction Steps
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
devices = jax.devices("neuron")[:2]
mesh = Mesh(np.array(devices), ("m",))
replicated = NamedSharding(mesh, PartitionSpec())
@jax.jit
def f_gather(src):
    idx = jnp.zeros((2, 1), dtype=jnp.int32)
    return lax.gather(
        src,
        idx,
        lax.GatherDimensionNumbers(
            offset_dims=(),
            collapsed_slice_dims=(0,),
            start_index_map=(0,),
            operand_batching_dims=(1,),
            start_indices_batching_dims=(0,),
        ),
        slice_sizes=(1, 1),
    )

@jax.jit
def f_take_along_axis(src):
    idx = jnp.zeros((src.shape[1],), dtype=jnp.int32)
    return jnp.take_along_axis(src, idx[None, :], axis=0)
src = jax.device_put(jnp.array([[10, 20], [30, 40]], dtype=jnp.int32), replicated)
# fn = f_gather
fn = f_take_along_axis
print(fn.lower(src).as_text())
print(fn(src))
Both of these functions crash with the check failure above.
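For reference, the result I expect (and get on CPU) from f_take_along_axis can be checked with plain NumPy:

```python
import numpy as np

# Same input as the reproduction; idx is all zeros, so every column
# takes its element from row 0.
src = np.array([[10, 20], [30, 40]], dtype=np.int32)
idx = np.zeros((src.shape[1],), dtype=np.int32)
out = np.take_along_axis(src, idx[None, :], axis=0)
print(out)  # → [[10 20]]
```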
Regression Issue
Possible Solution
No response
Logs/Context/Additional Information
I first encountered this issue with the following StableHLO module, without JAX, on the latest versions mentioned above:
module @zml attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x2xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2xi32> {mhlo.sharding = "{replicated}"}) {
%c = stablehlo.constant dense<0> : tensor<i32>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32>
%1 = stablehlo.reshape %0 : (tensor<2xi32>) -> tensor<2x1xi32>
%2 = "stablehlo.gather"(%arg0, %1) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], operand_batching_dims = [1], start_indices_batching_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<2x2xi32>, tensor<2x1xi32>) -> tensor<2xi32>
return %2 : tensor<2xi32>
}
}
The gather looks similar to the one generated by JAX:
module @jit_f_gather attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x2xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2xi32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<0> : tensor<i32>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x1xi32>
%1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], operand_batching_dims = [1], start_indices_batching_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<2x2xi32>, tensor<2x1xi32>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}
}
Here are the logs I got, in case they are needed.