Skip to content

Commit 3add079

Browse files
authored
🎨 workaround for path-dependent pyright bug by shuffling disjoint overloads (#745)
2 parents ccf6ae5 + efd1592 commit 3add079

File tree

3 files changed

+145
-141
lines changed

3 files changed

+145
-141
lines changed

‎scipy-stubs/cluster/vq.pyi‎

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable
1+
from collections.abc import Callable, Sequence
22
from types import ModuleType
33
from typing import Final, Literal, TypeAlias, overload
44
from typing_extensions import TypeVar
@@ -9,12 +9,17 @@ import optype.numpy.compat as npc
99

1010
__all__ = ["kmeans", "kmeans2", "vq", "whiten"]
1111

12+
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
13+
1214
_InitMethod: TypeAlias = Literal["random", "points", "++", "matrix"]
1315
_MissingMethod: TypeAlias = Literal["warn", "raise"]
1416

15-
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
17+
_ToFloat32_2D: TypeAlias = onp.ToArray2D[int, np.float32 | np.float16 | npc.integer16 | npc.integer8]
18+
_AsFloat64_2D: TypeAlias = onp.ToArray2D[float, npc.floating64 | npc.integer]
19+
_PyFloatMax2D: TypeAlias = Sequence[float] | Sequence[Sequence[float]]
1620

1721
###
22+
# NOTE: DO NOT RE-ORDER THE OVERLOADS WITH a `# type: ignore`, otherwise it'll trigger a pernicious bug in pyright (1.1.403).
1823

1924
class ClusterError(Exception): ...
2025

@@ -25,14 +30,14 @@ def whiten(obs: onp.ArrayND[np.bool_ | npc.integer], check_finite: bool | None =
2530
def whiten(obs: onp.ArrayND[_InexactT], check_finite: bool | None = None) -> onp.Array2D[_InexactT]: ...
2631

2732
#
33+
@overload # float32
34+
def vq( # type: ignore[overload-overlap]
35+
obs: onp.CanArrayND[np.float32], code_book: _ToFloat32_2D, check_finite: bool = True
36+
) -> tuple[onp.Array1D[np.int32], onp.Array1D[np.float32]]: ...
2837
@overload # float64
2938
def vq(
30-
obs: onp.ToJustFloat64_2D, code_book: onp.ToInt2D | onp.ToFloat64_2D, check_finite: bool = True
39+
obs: onp.ToJustFloat64_2D, code_book: _AsFloat64_2D, check_finite: bool = True
3140
) -> tuple[onp.Array1D[np.int32], onp.Array1D[np.float64]]: ...
32-
@overload # float32
33-
def vq(
34-
obs: onp.CanArrayND[np.float32], code_book: onp.ToInt2D | onp.CanArrayND[np.float16 | np.float32], check_finite: bool = True
35-
) -> tuple[onp.Array1D[np.int32], onp.Array1D[np.float32]]: ...
3641
@overload # floating
3742
def vq(
3843
obs: onp.ToJustFloat2D, code_book: onp.ToFloat2D, check_finite: bool = True
@@ -49,32 +54,32 @@ def py_vq(
4954
) -> tuple[onp.Array1D[np.intp], onp.Array1D[npc.floating]]: ...
5055

5156
#
52-
@overload # float64
53-
def kmeans(
54-
obs: onp.ToJustFloat64_2D,
55-
k_or_guess: onp.ToJustInt | onp.ToFloat64_ND,
57+
@overload # float32
58+
def kmeans( # type: ignore[overload-overlap]
59+
obs: onp.CanArrayND[np.float32],
60+
k_or_guess: int | _ToFloat32_2D,
5661
iter: int = 20,
5762
thresh: float = 1e-5,
5863
check_finite: bool = True,
5964
*,
6065
seed: onp.random.ToRNG | None = None,
6166
rng: onp.random.ToRNG | None = None,
62-
) -> tuple[onp.Array2D[np.float64], np.float64]: ...
63-
@overload # float32
67+
) -> tuple[onp.Array2D[np.float32], np.float32]: ...
68+
@overload # float64
6469
def kmeans(
65-
obs: onp.CanArrayND[np.float32],
66-
k_or_guess: onp.ToJustInt | onp.ToFloatND,
70+
obs: onp.ToJustFloat64_2D,
71+
k_or_guess: int | _AsFloat64_2D,
6772
iter: int = 20,
6873
thresh: float = 1e-5,
6974
check_finite: bool = True,
7075
*,
7176
seed: onp.random.ToRNG | None = None,
7277
rng: onp.random.ToRNG | None = None,
73-
) -> tuple[onp.Array2D[np.float32], np.float32]: ...
78+
) -> tuple[onp.Array2D[np.float64], np.float64]: ...
7479
@overload # floating
7580
def kmeans(
7681
obs: onp.ToJustFloat2D,
77-
k_or_guess: onp.ToJustInt | onp.ToFloatND,
82+
k_or_guess: int | onp.ToFloat2D,
7883
iter: int = 20,
7984
thresh: float = 1e-5,
8085
check_finite: bool = True,
@@ -104,10 +109,10 @@ def _missing_raise() -> None: ... # undocumented
104109
_valid_miss_meth: Final[dict[str, Callable[[], None]]] = ... # undocumented
105110

106111
#
107-
@overload # float64
108-
def kmeans2(
109-
data: onp.ToJustFloat64_1D | onp.ToJustFloat64_2D,
110-
k: onp.ToJustInt | onp.ToFloatND,
112+
@overload # float32
113+
def kmeans2( # type: ignore[overload-overlap]
114+
data: onp.CanArrayND[np.float32],
115+
k: int | _ToFloat32_2D,
111116
iter: int = 10,
112117
thresh: float = 1e-5,
113118
minit: _InitMethod = "random",
@@ -116,11 +121,11 @@ def kmeans2(
116121
*,
117122
seed: onp.random.ToRNG | None = None,
118123
rng: onp.random.ToRNG | None = None,
119-
) -> tuple[onp.Array2D[np.float64], onp.Array1D[np.int32]]: ...
120-
@overload # float32
124+
) -> tuple[onp.Array2D[np.float32], onp.Array1D[np.int32]]: ...
125+
@overload # float64
121126
def kmeans2(
122-
data: onp.CanArrayND[np.float32],
123-
k: onp.ToJustInt | onp.ToFloatND,
127+
data: onp.CanArrayND[np.float64] | _PyFloatMax2D,
128+
k: int | _AsFloat64_2D,
124129
iter: int = 10,
125130
thresh: float = 1e-5,
126131
minit: _InitMethod = "random",
@@ -129,11 +134,11 @@ def kmeans2(
129134
*,
130135
seed: onp.random.ToRNG | None = None,
131136
rng: onp.random.ToRNG | None = None,
132-
) -> tuple[onp.Array2D[np.float32], onp.Array1D[np.int32]]: ...
137+
) -> tuple[onp.Array2D[np.float64], onp.Array1D[np.int32]]: ...
133138
@overload # floating
134139
def kmeans2(
135-
data: onp.ToJustFloat1D | onp.ToJustFloat2D,
136-
k: onp.ToJustInt | onp.ToFloatND,
140+
data: onp.CanArrayND[npc.floating] | _PyFloatMax2D,
141+
k: int | onp.ToFloat2D,
137142
iter: int = 10,
138143
thresh: float = 1e-5,
139144
minit: _InitMethod = "random",

0 commit comments

Comments
 (0)