|  | 
| 19 | 19 | from pandas.core.arrays.timedeltas import TimedeltaArray | 
| 20 | 20 | from pandas.core.indexes.base import Index | 
| 21 | 21 | from pandas.core.indexes.category import CategoricalIndex | 
|  | 22 | +from pandas.core.indexes.datetimes import DatetimeIndex | 
| 22 | 23 | from typing_extensions import ( | 
| 23 | 24 |     Never, | 
| 24 | 25 |     assert_type, | 
| @@ -1541,3 +1542,39 @@ def test_multiindex_swaplevel() -> None: | 
| 1541 | 1542 |     """Test that MultiIndex.swaplevel returns MultiIndex""" | 
| 1542 | 1543 |     mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) | 
| 1543 | 1544 |     check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex) | 
|  | 1545 | + | 
|  | 1546 | + | 
|  | 1547 | +def test_index_where() -> None: | 
|  | 1548 | +    """Test Index.where with multiple types of other GH1419.""" | 
|  | 1549 | +    idx = pd.Index(range(48)) | 
|  | 1550 | +    mask = np.ones(48, dtype=bool) | 
|  | 1551 | +    val_idx = idx.where(mask, idx) | 
|  | 1552 | +    check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int) | 
|  | 1553 | + | 
|  | 1554 | +    val_sr = idx.where(mask, (idx).to_series()) | 
|  | 1555 | +    check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int) | 
|  | 1556 | + | 
|  | 1557 | + | 
|  | 1558 | +def test_datetimeindex_where() -> None: | 
|  | 1559 | +    """Test DatetimeIndex.where with multiple types of other GH1419.""" | 
|  | 1560 | +    datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48) | 
|  | 1561 | +    mask = np.ones(48, dtype=bool) | 
|  | 1562 | +    val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1)) | 
|  | 1563 | +    check(assert_type(val_idx, DatetimeIndex), DatetimeIndex) | 
|  | 1564 | + | 
|  | 1565 | +    val_sr = datetime_index.where( | 
|  | 1566 | +        mask, (datetime_index - pd.Timedelta(days=1)).to_series() | 
|  | 1567 | +    ) | 
|  | 1568 | +    check(assert_type(val_sr, DatetimeIndex), DatetimeIndex) | 
|  | 1569 | + | 
|  | 1570 | +    val_idx_scalar = datetime_index.where(mask, pd.Index([0, 1])) | 
|  | 1571 | +    check(assert_type(val_idx_scalar, pd.Index), pd.Index) | 
|  | 1572 | + | 
|  | 1573 | +    val_sr_scalar = datetime_index.where(mask, pd.Series([0, 1])) | 
|  | 1574 | +    check(assert_type(val_sr_scalar, pd.Index), pd.Index) | 
|  | 1575 | + | 
|  | 1576 | +    val_scalar = datetime_index.where(mask, 1) | 
|  | 1577 | +    check(assert_type(val_scalar, pd.Index), pd.Index) | 
|  | 1578 | + | 
|  | 1579 | +    val_range = pd.RangeIndex(2).where(pd.Series([True, False]), 3) | 
|  | 1580 | +    check(assert_type(val_range, pd.Index), pd.RangeIndex) | 
0 commit comments