Skip to content

Commit 86c94eb

Browse files
GH1419 Allow Series and Index for other in Index.where(..., other)
1 parent 1cdecd9 commit 86c94eb

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ from pandas._typing import (
6161
T_COMPLEX,
6262
T_INT,
6363
AnyAll,
64-
ArrayLike,
64+
AnyArrayLike,
6565
AxesData,
6666
CategoryDtypeArg,
6767
DropKeep,
@@ -434,7 +434,7 @@ class Index(IndexOpsMixin[S1]):
434434
@property
435435
def values(self) -> np_1darray: ...
436436
def memory_usage(self, deep: bool = False): ...
437-
def where(self, cond, other: Scalar | ArrayLike | None = None): ...
437+
def where(self, cond, other: Scalar | AnyArrayLike | None = None) -> Self: ...
438438
def __contains__(self, key) -> bool: ...
439439
@final
440440
def __setitem__(self, key, value) -> None: ...

tests/indexes/test_indexes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pandas.core.arrays.timedeltas import TimedeltaArray
2020
from pandas.core.indexes.base import Index
2121
from pandas.core.indexes.category import CategoricalIndex
22+
from pandas.core.indexes.datetimes import DatetimeIndex
2223
from typing_extensions import (
2324
Never,
2425
assert_type,
@@ -1608,3 +1609,18 @@ def test_to_series() -> None:
16081609
np.complexfloating,
16091610
)
16101611
check(assert_type(Index(["1"]).to_series(), "pd.Series[str]"), pd.Series, str)
1612+
1613+
1614+
def test_index_where() -> None:
1615+
"""Test Index.where with multiple types of other GH1419."""
1616+
datetime_index = pd.DatetimeIndex(
1617+
pd.date_range(start="2025-01-01", freq="h", periods=48)
1618+
)
1619+
mask = np.ones(48, dtype=bool)
1620+
val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1))
1621+
check(assert_type(val_idx, DatetimeIndex), DatetimeIndex)
1622+
1623+
val_sr = datetime_index.where(
1624+
mask, (datetime_index - pd.Timedelta(days=1)).to_series()
1625+
)
1626+
check(assert_type(val_sr, DatetimeIndex), DatetimeIndex)

0 commit comments

Comments
 (0)