Skip to content

Commit 5f881bf

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent f7ab683 commit 5f881bf

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

jax/_src/numpy/reductions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,9 +2383,9 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23832383
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
23842384
"or out != None")
23852385
if not isinstance(interpolation, DeprecatedArg):
2386-
raise TypeError("nanquantile() argument interpolation was removed in JAX"
2386+
raise TypeError("quantile() argument interpolation was removed in JAX"
23872387
" v0.8.0. Use method instead.")
2388-
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False, weights)
2388+
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False, weights)
23892389

23902390
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
23912391
@export
@@ -2442,7 +2442,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24422442
if not isinstance(interpolation, DeprecatedArg):
24432443
raise TypeError("nanquantile() argument interpolation was removed in JAX"
24442444
" v0.8.0. Use method instead.")
2445-
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True, weights)
2445+
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights)
24462446

24472447
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24482448
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array:

0 commit comments

Comments
 (0)