11from __future__ import annotations
22
33from builtins import type as type_t
4- from collections .abc import Callable
4+ from collections .abc import (
5+ Callable ,
6+ Iterable ,
7+ )
58import decimal
69import numbers
710import sys
2730 ExtensionScalarOpsMixin ,
2831)
2932from pandas .core .indexers import check_array_indexer
33+ from pandas .core .series import Series
34+ from typing_extensions import Self
3035
3136from pandas ._typing import (
3237 ArrayLike ,
3338 AstypeArg ,
39+ Dtype ,
40+ ScalarIndexer ,
41+ SequenceIndexer ,
42+ SequenceNotStr ,
3443 TakeIndexer ,
44+ np_1darray ,
3545)
3646
3747from pandas .core .dtypes .base import ExtensionDtype
4151 pandas_dtype ,
4252)
4353
44- from tests import np_1darray
45-
4654
4755@register_extension_dtype
4856class DecimalDtype (ExtensionDtype ):
@@ -82,9 +90,9 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
8290 def __init__ (
8391 self ,
8492 values : list [decimal .Decimal | float ] | np .ndarray ,
85- dtype = None ,
86- copy = False ,
87- context = None ,
93+ dtype : DecimalDtype | None = None ,
94+ copy : bool = False ,
95+ context : decimal . Context | None = None ,
8896 ) -> None :
8997 for i , val in enumerate (values ):
9098 if is_float (val ):
@@ -111,25 +119,37 @@ def dtype(self) -> DecimalDtype:
111119 return self ._dtype
112120
113121 @classmethod
114- def _from_sequence (cls , scalars , dtype = None , copy = False ):
122+ def _from_sequence (
123+ cls ,
124+ scalars : list [decimal .Decimal | float ] | np .ndarray ,
125+ dtype : DecimalDtype | None = None ,
126+ copy : bool = False ,
127+ ) -> Self :
115128 return cls (scalars )
116129
117130 @classmethod
118- def _from_sequence_of_strings (cls , strings , dtype = None , copy = False ):
131+ def _from_sequence_of_strings (
132+ cls ,
133+ strings : SequenceNotStr [str ],
134+ dtype : DecimalDtype | None = None ,
135+ copy : bool = False ,
136+ ) -> Self :
119137 return cls ._from_sequence ([decimal .Decimal (x ) for x in strings ], dtype , copy )
120138
121139 @classmethod
122- def _from_factorized (cls , values , original ):
140+ def _from_factorized (
141+ cls , values : list [decimal .Decimal | float ] | np .ndarray , original : Any
142+ ) -> Self :
123143 return cls (values )
124144
125145 _HANDLED_TYPES = (decimal .Decimal , numbers .Number , np .ndarray )
126146
127147 def to_numpy (
128148 self ,
129- dtype = None ,
149+ dtype : np . typing . DTypeLike | None = None ,
130150 copy : bool = False ,
131151 na_value : object = no_default ,
132- decimals = None ,
152+ decimals : int | None = None ,
133153 ) -> np .ndarray :
134154 result = np .asarray (self , dtype = dtype )
135155 if decimals is not None :
@@ -138,7 +158,7 @@ def to_numpy(
138158
139159 def __array_ufunc__ (
140160 self , ufunc : np .ufunc , method : str , * inputs : Any , ** kwargs : Any
141- ):
161+ ) -> arraylike . dispatch_ufunc_with_out : # type: ignore[name-defined] # pyright: ignore[reportAttributeAccessIssue]
142162 #
143163 if not all (
144164 isinstance (t , self ._HANDLED_TYPES + (DecimalArray ,)) for t in inputs
@@ -160,7 +180,14 @@ def __array_ufunc__(
160180 if result is not NotImplemented :
161181 return result
162182
163- def reconstruct (x ) -> decimal .Decimal | numbers .Number | DecimalArray :
183+ def reconstruct (
184+ x : (
185+ decimal .Decimal
186+ | numbers .Number
187+ | list [decimal .Decimal | float ]
188+ | np .ndarray
189+ ),
190+ ) -> decimal .Decimal | numbers .Number | DecimalArray :
164191 if isinstance (x , (decimal .Decimal , numbers .Number )):
165192 return x
166193 return DecimalArray ._from_sequence (x )
@@ -169,15 +196,17 @@ def reconstruct(x) -> decimal.Decimal | numbers.Number | DecimalArray:
169196 return tuple (reconstruct (x ) for x in result )
170197 return reconstruct (result )
171198
172- def __getitem__ (self , item ) :
199+ def __getitem__ (self , item : ScalarIndexer | SequenceIndexer ) -> Any :
173200 if isinstance (item , numbers .Integral ):
174201 return self ._data [item ]
175202 # array, slice.
176- item = check_array_indexer (self , item )
203+ item = check_array_indexer (
204+ self , item # type: ignore[arg-type] # pyright: ignore[reportArgumentType,reportCallIssue]
205+ )
177206 return type (self )(self ._data [item ])
178207
179208 def take (
180- self , indexer : TakeIndexer , * , allow_fill : bool = False , fill_value = None
209+ self , indexer : TakeIndexer , * , allow_fill : bool = False , fill_value : Any = None
181210 ) -> DecimalArray :
182211 from pandas .api .extensions import take
183212
@@ -208,21 +237,26 @@ def astype(self, dtype, copy=True):
208237
209238 return super ().astype (dtype , copy = copy )
210239
211- def __setitem__ (self , key , value ) -> None :
240+ def __setitem__ (self , key : object , value : decimal . _DecimalNew ) -> None :
212241 if is_list_like (value ):
213242 if is_scalar (key ):
214243 raise ValueError ("setting an array element with a sequence." )
215- value = [decimal .Decimal (v ) for v in value ]
244+ value = [ # type: ignore[assignment]
245+ decimal .Decimal (v ) # type: ignore[arg-type]
246+ for v in value # type: ignore[union-attr] # pyright: ignore[reportAssignmentType,reportGeneralTypeIssues]
247+ ]
216248 else :
217249 value = decimal .Decimal (value )
218250
219- key = check_array_indexer (self , key )
220- self ._data [key ] = value
251+ key = check_array_indexer ( # type: ignore[call-overload]
252+ self , key # pyright: ignore[reportArgumentType,reportCallIssue]
253+ )
254+ self ._data [key ] = value # type: ignore[call-overload] # pyright: ignore[reportArgumentType,reportCallIssue]
221255
222256 def __len__ (self ) -> int :
223257 return len (self ._data )
224258
225- def __contains__ (self , item ) -> bool | np .bool_ :
259+ def __contains__ (self , item : Any ) -> bool | np .bool_ :
226260 if not isinstance (item , decimal .Decimal ):
227261 return False
228262 if item .is_nan ():
@@ -236,20 +270,20 @@ def nbytes(self) -> int:
236270 return n * sys .getsizeof (self [0 ])
237271 return 0
238272
239- def isna (self ):
273+ def isna (self ) -> np_1darray [ np . bool_ ] :
240274 return np .array ([x .is_nan () for x in self ._data ], dtype = bool )
241275
242276 @property
243277 def _na_value (self ) -> decimal .Decimal :
244278 return decimal .Decimal ("NaN" )
245279
246- def _formatter (self , boxed = False ) -> Callable [..., str ]:
280+ def _formatter (self , boxed : bool = False ) -> Callable [..., str ]:
247281 if boxed :
248282 return "Decimal: {}" .format
249283 return repr
250284
251285 @classmethod
252- def _concat_same_type (cls , to_concat ) :
286+ def _concat_same_type (cls , to_concat : Iterable [ Self ]) -> Self :
253287 return cls (np .concatenate ([x ._data for x in to_concat ]))
254288
255289 def _reduce (self , name : str , * , skipna : bool = True , ** kwargs : Any ) -> Any :
@@ -271,9 +305,11 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs: Any) -> Any:
271305 ) from err
272306 return op (axis = 0 )
273307
274- def _cmp_method (self , other , op ) -> np .ndarray [tuple [int ], np .dtype [np .bool_ ]]:
308+ def _cmp_method (
309+ self , other : Any , op : Callable [[Self , ExtensionArray | list [Any ]], bool ]
310+ ) -> np_1darray [np .bool_ ]:
275311 # For use with OpsMixin
276- def convert_values (param ) -> ExtensionArray | list [Any ]:
312+ def convert_values (param : Any ) -> ExtensionArray | list [Any ]:
277313 if isinstance (param , ExtensionArray ) or is_list_like (param ):
278314 ovalues = param
279315 else :
@@ -292,7 +328,7 @@ def convert_values(param) -> ExtensionArray | list[Any]:
292328 np .ndarray [tuple [int ], np .dtype [np .bool_ ]], np .asarray (res , dtype = bool )
293329 )
294330
295- def value_counts (self , dropna : bool = True ):
331+ def value_counts (self , dropna : bool = True ) -> Series :
296332 from pandas .core .algorithms import value_counts
297333
298334 return value_counts (self .to_numpy (), dropna = dropna )
0 commit comments