Skip to content

Commit 0627c84

Browse files
hawkinsp authored and Google-ML-Automation committed
Fix mishandling of large matrices in batched eigendecomposition.
Fixes #33062 PiperOrigin-RevId: 830926150
1 parent 7bb44b6 commit 0627c84

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4343
decomposition on CUDA GPUs. This is also an alias for the existing algorithm
4444
on TPUs.
4545
46+
* Bug fixes:
47+
48+
* Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
49+
GPU ({jax-issue}`#33062`).
50+
4651
* Deprecations:
4752
* Default `axis_types` of `jax.make_mesh` will change in JAX v0.9.0 to return
4853
`jax.sharding.AxisType.Explicit`. Leaving axis_types unspecified will raise a

jaxlib/gpu/solver_kernels_ffi.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,10 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
537537
int64_t batch_step = 1;
538538
FFI_ASSIGN_OR_RETURN(bool is_batched_syev_supported,
539539
IsSyevBatchedSupported());
540-
if (is_batched_syev_supported) {
540+
if (is_batched_syev_supported && n > 0) {
541541
int64_t matrix_size = n * n * ffi::ByteWidth(dataType);
542-
batch_step = std::numeric_limits<int>::max() / matrix_size;
542+
batch_step =
543+
std::max(int64_t(1), std::numeric_limits<int>::max() / matrix_size);
543544
if (batch_step >= 32 * 1024) {
544545
batch_step = 32 * 1024;
545546
}
@@ -585,7 +586,7 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
585586

586587
for (int64_t i = 0; i < batch; i += batch_step) {
587588
size_t batch_size = static_cast<size_t>(std::min(batch_step, batch - i));
588-
if (is_batched_syev_supported) {
589+
if (is_batched_syev_supported && batch_step > 1) {
589590
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched(
590591
handle.get(), params, jobz, uplo, n, aType, out_data, n, wType,
591592
w_data, aType, workspaceOnDevice, workspaceInBytesOnDevice,

tests/linalg_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import partial
1616
import itertools
1717
from collections.abc import Iterator
18+
import unittest
1819

1920
import numpy as np
2021
import scipy
@@ -457,6 +458,15 @@ def testEigh(self, n, dtype, lower):
457458
w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol
458459
)
459460

461+
@jax._src.config.explicit_x64_dtypes("allow")
462+
@jtu.run_on_devices("gpu")
463+
@unittest.skip("Needs a large amount of GPU memory, doesn't work in CI")
464+
def testEighLargeMatrix(self):
465+
# https://github.com/jax-ml/jax/issues/33062
466+
n = 16384
467+
A = jnp.eye(n, dtype=jnp.float64)
468+
jax.block_until_ready(jax.lax.linalg.eigh(A))
469+
460470
@jtu.sample_product(
461471
start=[0, 1, 63, 64, 65, 255],
462472
end=[1, 63, 64, 65, 256],

0 commit comments

Comments (0)