XLA crash on stablehlo.gather on multi-core Neuron #1309

@elogir

Description

Describe the bug

Compiling what's provided in the reproduction leads to the following:
F external/xla/xla/permutation_util.h:47] Check failed: permutation.size() == data.size() (1 vs. 2)
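For context, the failed check in xla/permutation_util.h asserts that a permutation has exactly one entry per element of the data it reorders. A minimal sketch of that invariant (hypothetical helper, not XLA's actual code):

```python
# Hypothetical sketch of the invariant behind the failed check: a permutation
# must be exactly as long as the data it permutes.
def permute(data, permutation):
    # Mirrors the `permutation.size() == data.size()` check; the crash above
    # hits the (1 vs 2) mismatch case.
    assert len(permutation) == len(data), (len(permutation), len(data))
    return [data[i] for i in permutation]

print(permute(["a", "b"], [1, 0]))  # OK: 2-element permutation for 2 elements
# permute(["a", "b"], [0])  # would trip the same size check as the crash
```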

Note that this only happens when using multiple cores (tested with 2); forcing NEURON_RT_VISIBLE_CORES=0 makes it work. It also works on CPU.
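The single-core workaround can be applied from Python, as a sketch, by setting the variable before JAX initializes the Neuron runtime (the exact initialization point may vary):

```python
import os

# Restrict the Neuron runtime to core 0. This must happen before JAX (and the
# Neuron plugin) initialize; setting it afterwards has no effect.
os.environ["NEURON_RT_VISIBLE_CORES"] = "0"

# import jax  # import only after the variable is set
```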

Model Name

N/A

Describe the workload type

N/A

Instance Type

inf2.8xlarge

Release version

aws-neuronx-collectives/unknown,now 2.31.24.0-1a31ba186 amd64
aws-neuronx-dkms/unknown,now 2.27.4.0 all
aws-neuronx-runtime-lib/unknown,now 2.31.24.0-0b044f4ce amd64
aws-neuronx-tools/unknown,now 2.29.18.0-d5fe7ba42 amd64

Using Python 3.13.13
jax          0.6.2
jax-neuronx  0.7.0.1.0.8181+1e892be0
jaxlib       0.6.2
libneuronxla 2.2.16408.0+50c26cbd
neuronx-cc   2.23.6484.0+3b612583

I have also tested these versions without JAX and got the same crash on:

aws-neuronx-runtime-lib 2.29.40.0-f954cd7a5
aws-neuronx-collectives 2.29.41.0-681fef5f5
libneuronxla 2.2.16408.0+50c26cbd
neuronx-cc 2.24.5133.0+58f8de22

Reproduction Steps

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices("neuron")[:2]
mesh = Mesh(np.array(devices), ("m",))
replicated = NamedSharding(mesh, PartitionSpec())


@jax.jit
def f_gather(src):
    idx = jnp.zeros((2, 1), dtype=jnp.int32)
    return lax.gather(
        src,
        idx,
        lax.GatherDimensionNumbers(
            offset_dims=(),
            collapsed_slice_dims=(0,),
            start_index_map=(0,),
            operand_batching_dims=(1,),
            start_indices_batching_dims=(0,),
        ),
        slice_sizes=(1, 1),
    )


@jax.jit
def f_take_along_axis(src):
    idx = jnp.zeros((src.shape[1],), dtype=jnp.int32)
    return jnp.take_along_axis(src, idx[None, :], axis=0)


src = jax.device_put(jnp.array([[10, 20], [30, 40]], dtype=jnp.int32), replicated)
# fn = f_gather
fn = f_take_along_axis

print(fn.lower(src).as_text())
print(fn(src))

Both of these functions crash with the same check failure.
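For reference, a plain NumPy version of f_take_along_axis shows the value the repro should produce (same data, CPU only):

```python
import numpy as np

# CPU reference for the repro: gather index 0 along axis 0 for every column.
src = np.array([[10, 20], [30, 40]], dtype=np.int32)
idx = np.zeros((1, src.shape[1]), dtype=np.int32)
out = np.take_along_axis(src, idx, axis=0)
print(out)  # [[10 20]]
```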

Regression Issue

  • Select this option if this issue appears to be a regression.

Possible Solution

No response

Logs/Context/Additional Information

The first time I encountered this issue was with the following code, without JAX, on the latest version mentioned:

module @zml attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x2xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2xi32> {mhlo.sharding = "{replicated}"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32>
    %1 = stablehlo.reshape %0 : (tensor<2xi32>) -> tensor<2x1xi32>
    %2 = "stablehlo.gather"(%arg0, %1) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], operand_batching_dims = [1], start_indices_batching_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<2x2xi32>, tensor<2x1xi32>) -> tensor<2xi32>
    return %2 : tensor<2xi32>
  }
}

The gather looks similar to the one generated by JAX:

module @jit_f_gather attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x2xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2xi32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x1xi32>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], operand_batching_dims = [1], start_indices_batching_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<2x2xi32>, tensor<2x1xi32>) -> tensor<2xi32>
    return %1 : tensor<2xi32>
  }
}

Here are the logs I got, in case they are needed.


Labels

Inf2, bug
