Skip to content

Commit 251b9ad

Browse files
authored
Type CategoricalIndex (#1459)
* wip
* bit more
* testing
* more tests
* fixup `insert`
* include `np.integer`, as per upstream
* type `categories` too
* update pyproject.toml
1 parent 84f8d36 commit 251b9ad

File tree

5 files changed

+71
-34
lines changed

5 files changed

+71
-34
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ from pandas._typing import (
6969
T_COMPLEX,
7070
AnyAll,
7171
AnyArrayLike,
72+
AnyArrayLikeInt,
7273
ArrayLike,
7374
AxesData,
7475
CategoryDtypeArg,
@@ -406,7 +407,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
406407
notnull = ...
407408
def fillna(self, value=...): ...
408409
def dropna(self, how: AnyAll = "any") -> Self: ...
409-
def unique(self, level=...) -> Self: ...
410+
def unique(self, level: Hashable | None = None) -> Self: ...
410411
def drop_duplicates(self, *, keep: DropKeep = ...) -> Self: ...
411412
def duplicated(self, keep: DropKeep = "first") -> np_1darray[np.bool]: ...
412413
def __and__(self, other: Never) -> Never: ...
@@ -442,12 +443,12 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
442443
) -> np_1darray[np.intp]: ...
443444
def reindex(
444445
self,
445-
target,
446-
method: ReindexMethod | None = ...,
447-
level=...,
448-
limit=...,
449-
tolerance=...,
450-
): ...
446+
target: Iterable[Any],
447+
method: ReindexMethod | None = None,
448+
level: int | None = None,
449+
limit: int | None = None,
450+
tolerance: Scalar | AnyArrayLike | Sequence[Scalar] | None = None,
451+
) -> tuple[Index, np_1darray[np.intp] | None]: ...
451452
@overload
452453
def join(
453454
self,
@@ -483,7 +484,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
483484
cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool],
484485
other: Scalar | AnyArrayLike | None = None,
485486
) -> Index: ...
486-
def __contains__(self, key) -> bool: ...
487+
def __contains__(self, key: Hashable) -> bool: ...
487488
@final
488489
def __setitem__(self, key, value) -> None: ...
489490
@overload
@@ -500,7 +501,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
500501
@overload
501502
def append(self, other: Index | Sequence[Index]) -> Index: ...
502503
def putmask(self, mask, value): ...
503-
def equals(self, other) -> bool: ...
504+
def equals(self, other: object) -> bool: ...
504505
@final
505506
def identical(self, other) -> bool: ...
506507
@final
@@ -534,8 +535,13 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
534535
def slice_locs(
535536
self, start: SliceType = None, end: SliceType = None, step: int | None = None
536537
): ...
537-
def delete(self, loc) -> Self: ...
538-
def insert(self, loc, item) -> Self: ...
538+
def delete(
539+
self, loc: np.integer | int | AnyArrayLikeInt | Sequence[int]
540+
) -> Self: ...
541+
@overload
542+
def insert(self, loc: int, item: S1) -> Self: ...
543+
@overload
544+
def insert(self, loc: int, item: object) -> Index: ...
539545
def drop(self, labels, errors: IgnoreRaise = "raise") -> Self: ...
540546
@property
541547
def shape(self) -> tuple[int, ...]: ...

pandas-stubs/core/indexes/category.pyi

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ from collections.abc import (
22
Hashable,
33
Iterable,
44
)
5-
from typing import final
65

76
import numpy as np
87
from pandas.core import accessor
@@ -11,7 +10,11 @@ from pandas.core.indexes.base import Index
1110
from pandas.core.indexes.extension import ExtensionIndex
1211
from typing_extensions import Self
1312

14-
from pandas._typing import S1
13+
from pandas._typing import (
14+
S1,
15+
Dtype,
16+
ListLike,
17+
)
1518

1619
class CategoricalIndex(ExtensionIndex[S1], accessor.PandasDelegate):
1720
codes: np.ndarray = ...
@@ -21,28 +24,20 @@ class CategoricalIndex(ExtensionIndex[S1], accessor.PandasDelegate):
2124
def __new__(
2225
cls,
2326
data: Iterable[S1] = ...,
24-
categories=...,
25-
ordered=...,
26-
dtype=...,
27-
copy: bool = ...,
28-
name: Hashable = ...,
27+
categories: ListLike | None = None,
28+
ordered: bool | None = None,
29+
dtype: Dtype | None = None,
30+
copy: bool = False,
31+
name: Hashable | None = None,
2932
) -> Self: ...
30-
def equals(self, other): ...
3133
@property
3234
def inferred_type(self) -> str: ...
3335
@property
34-
def values(self): ...
35-
def __contains__(self, key) -> bool: ...
36-
@property
3736
def is_unique(self) -> bool: ...
3837
@property
3938
def is_monotonic_increasing(self) -> bool: ...
4039
@property
4140
def is_monotonic_decreasing(self) -> bool: ...
42-
def unique(self, level=...): ...
43-
def reindex(self, target, method=..., level=..., limit=..., tolerance=...): ...
44-
@final
45-
def get_indexer(self, target, method=..., limit=..., tolerance=...): ...
46-
def get_indexer_non_unique(self, target): ...
47-
def delete(self, loc): ...
48-
def insert(self, loc, item): ...
41+
# `item` might be `S1` but not one of the categories, thus changing
42+
# the return type from `CategoricalIndex` to `Index`.
43+
def insert(self, loc: int, item: object) -> Index: ... # type: ignore[override]

pandas-stubs/core/indexes/range.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class RangeIndex(_IndexSubclassBase[int, np.int64]):
6161
def is_monotonic_decreasing(self) -> bool: ...
6262
@property
6363
def has_duplicates(self) -> bool: ...
64-
def __contains__(self, key: int | np.integer) -> bool: ...
6564
def factorize(
6665
self, sort: bool = False, use_na_sentinel: bool = True
6766
) -> tuple[np_1darray[np.intp], RangeIndex]: ...

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,6 @@ ignore = [
204204
"PYI042", # https://docs.astral.sh/ruff/rules/snake-case-type-alias/
205205
"ERA001", "PLR0402", "PLC0105"
206206
]
207-
"*category.pyi" = [
208-
# TODO: remove when pandas-dev/pandas-stubs#1443 is resolved
209-
"ANN001", "ANN201", "ANN204", "ANN206",
210-
]
211207
"*series.pyi" = [
212208
# TODO: remove when pandas-dev/pandas-stubs#1444 is resolved
213209
"ANN001", "ANN201", "ANN204", "ANN206",
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import pandas as pd
5+
from typing_extensions import (
6+
assert_type,
7+
)
8+
9+
from tests import (
10+
check,
11+
np_1darray,
12+
)
13+
14+
15+
def test_categoricalindex_unique() -> None:
16+
ci = pd.CategoricalIndex(["a", "b"])
17+
check(
18+
assert_type(ci.unique(), "pd.CategoricalIndex[str]"),
19+
pd.CategoricalIndex,
20+
)
21+
22+
23+
def test_categoricalindex_reindex() -> None:
24+
ci = pd.CategoricalIndex(["a", "b"])
25+
check(
26+
assert_type(ci.reindex([0, 1]), tuple[pd.Index, np_1darray[np.intp] | None]),
27+
tuple,
28+
)
29+
30+
31+
def test_categoricalindex_delete() -> None:
32+
ci = pd.CategoricalIndex(["a", "b"])
33+
check(assert_type(ci.delete(0), "pd.CategoricalIndex[str]"), pd.CategoricalIndex)
34+
check(
35+
assert_type(ci.delete([0, 1]), "pd.CategoricalIndex[str]"), pd.CategoricalIndex
36+
)
37+
38+
39+
def test_categoricalindex_insert() -> None:
40+
ci = pd.CategoricalIndex(["a", "b"])
41+
check(assert_type(ci.insert(0, "c"), pd.Index), pd.Index)

0 commit comments

Comments
 (0)