Skip to content

Commit c7d6823

Browse files
GH1419 Allow Series and Index for other in Index.where(..., other)
1 parent 143bab4 commit c7d6823

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ from pandas._typing import (
6767
S2_NSDT,
6868
T_COMPLEX,
6969
AnyAll,
70+
AnyArrayLike,
7071
ArrayLike,
7172
AxesData,
7273
CategoryDtypeArg,
@@ -440,7 +441,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
440441
@property
441442
def values(self) -> np_1darray: ...
442443
def memory_usage(self, deep: bool = False): ...
443-
def where(self, cond, other: Scalar | ArrayLike | None = None): ...
444+
def where(self, cond, other: Scalar | AnyArrayLike | None = None) -> Self: ...
444445
def __contains__(self, key) -> bool: ...
445446
@final
446447
def __setitem__(self, key, value) -> None: ...

tests/indexes/test_indexes.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pandas.core.arrays.timedeltas import TimedeltaArray
2020
from pandas.core.indexes.base import Index
2121
from pandas.core.indexes.category import CategoricalIndex
22+
from pandas.core.indexes.datetimes import DatetimeIndex
2223
from typing_extensions import (
2324
Never,
2425
assert_type,
@@ -1541,3 +1542,27 @@ def test_multiindex_swaplevel() -> None:
15411542
"""Test that MultiIndex.swaplevel returns MultiIndex"""
15421543
mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"])
15431544
check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex)
1545+
1546+
1547+
def test_index_where() -> None:
1548+
"""Test Index.where with multiple types of other GH1419."""
1549+
idx = pd.Index(range(48))
1550+
mask = np.ones(48, dtype=bool)
1551+
val_idx = idx.where(mask, idx)
1552+
check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int)
1553+
1554+
val_sr = idx.where(mask, (idx).to_series())
1555+
check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int)
1556+
1557+
1558+
def test_datetimeindex_where() -> None:
1559+
"""Test DatetimeIndex.where with multiple types of other GH1419."""
1560+
datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48)
1561+
mask = np.ones(48, dtype=bool)
1562+
val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1))
1563+
check(assert_type(val_idx, DatetimeIndex), DatetimeIndex)
1564+
1565+
val_sr = datetime_index.where(
1566+
mask, (datetime_index - pd.Timedelta(days=1)).to_series()
1567+
)
1568+
check(assert_type(val_sr, DatetimeIndex), DatetimeIndex)

0 commit comments

Comments
 (0)