1212from typing import NamedTuple
1313import inspect
1414
15- from ._helpers import array_namespace , _check_device
15+ from ._helpers import array_namespace , _check_device , device , is_torch_array
1616
1717# These functions are modified from the NumPy versions.
1818
@@ -264,6 +264,38 @@ def var(
264264) -> ndarray :
265265 return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
266266
267+ # cumulative_sum is renamed from cumsum, and adds the include_initial keyword
268+ # argument
269+
def cumulative_sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    include_initial: bool = False,
    **kwargs
) -> ndarray:
    """Cumulative sum of ``x`` along ``axis``, built on ``xp.cumsum``.

    Parameters
    ----------
    x
        Input array.
    xp
        Namespace whose ``cumsum`` and ``concatenate`` are called.
    axis
        Axis to accumulate over. When ``None``, ``x`` may have at most one
        dimension and axis 0 is used.
    dtype
        Forwarded unchanged to ``xp.cumsum``.
    include_initial
        When true, a slice of zeros of length 1 along ``axis`` is prepended
        to the result.
    **kwargs
        Forwarded unchanged to ``xp.cumsum``.

    Raises
    ------
    ValueError
        If ``axis`` is ``None`` and ``x`` has more than one dimension.
    """
    # Resolved up front, exactly as before; only used to allocate the zeros.
    compat_ns = array_namespace(x)

    # TODO: The standard is not clear about what should happen when x.ndim == 0.
    if axis is None:
        if x.ndim > 1:
            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
        axis = 0

    running = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
    if not include_initial:
        return running

    # xp.cumsum has no include_initial option, so prepend the zeros by hand.
    # The zeros take the dtype and device of the cumsum result so that the
    # concatenation does not change either.
    zero_shape = list(x.shape)
    zero_shape[axis] = 1
    leading_zeros = compat_ns.zeros(
        shape=zero_shape,
        dtype=running.dtype,
        device=device(running),
    )
    return xp.concatenate([leading_zeros, running], axis=axis)
267299
268300# The min and max argument names in clip are different and not optional in numpy, and type
269301# promotion behavior is different.
@@ -281,10 +313,11 @@ def _isscalar(a):
281313 return isinstance (a , (int , float , type (None )))
282314 min_shape = () if _isscalar (min ) else min .shape
283315 max_shape = () if _isscalar (max ) else max .shape
284- result_shape = xp .broadcast_shapes (x .shape , min_shape , max_shape )
285316
286317 wrapped_xp = array_namespace (x )
287318
319+ result_shape = xp .broadcast_shapes (x .shape , min_shape , max_shape )
320+
288321 # np.clip does type promotion but the array API clip requires that the
289322 # output have the same dtype as x. We do this instead of just downcasting
290323 # the result of xp.clip() to handle some corner cases better (e.g.,
@@ -305,20 +338,26 @@ def _isscalar(a):
305338
306339 # At least handle the case of Python integers correctly (see
307340 # https://github.com/numpy/numpy/pull/26892).
308- if type (min ) is int and min <= xp .iinfo (x .dtype ).min :
341+ if type (min ) is int and min <= wrapped_xp .iinfo (x .dtype ).min :
309342 min = None
310- if type (max ) is int and max >= xp .iinfo (x .dtype ).max :
343+ if type (max ) is int and max >= wrapped_xp .iinfo (x .dtype ).max :
311344 max = None
312345
313346 if out is None :
314- out = wrapped_xp .asarray (xp .broadcast_to (x , result_shape ), copy = True )
347+ out = wrapped_xp .asarray (xp .broadcast_to (x , result_shape ),
348+ copy = True , device = device (x ))
315349 if min is not None :
316- a = xp .broadcast_to (xp .asarray (min ), result_shape )
350+ if is_torch_array (x ) and x .dtype == xp .float64 and _isscalar (min ):
351+ # Avoid loss of precision due to torch defaulting to float32
352+ min = wrapped_xp .asarray (min , dtype = xp .float64 )
353+ a = xp .broadcast_to (wrapped_xp .asarray (min , device = device (x )), result_shape )
317354 ia = (out < a ) | xp .isnan (a )
318355 # torch requires an explicit cast here
319356 out [ia ] = wrapped_xp .astype (a [ia ], out .dtype )
320357 if max is not None :
321- b = xp .broadcast_to (xp .asarray (max ), result_shape )
358+ if is_torch_array (x ) and x .dtype == xp .float64 and _isscalar (max ):
359+ max = wrapped_xp .asarray (max , dtype = xp .float64 )
360+ b = xp .broadcast_to (wrapped_xp .asarray (max , device = device (x )), result_shape )
322361 ib = (out > b ) | xp .isnan (b )
323362 out [ib ] = wrapped_xp .astype (b [ib ], out .dtype )
324363 # Return a scalar for 0-D
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
389428 raise ValueError ("nonzero() does not support zero-dimensional arrays" )
390429 return xp .nonzero (x , ** kwargs )
391430
392- # sum() and prod() should always upcast when dtype=None
393- def sum (
394- x : ndarray ,
395- / ,
396- xp ,
397- * ,
398- axis : Optional [Union [int , Tuple [int , ...]]] = None ,
399- dtype : Optional [Dtype ] = None ,
400- keepdims : bool = False ,
401- ** kwargs ,
402- ) -> ndarray :
403- # `xp.sum` already upcasts integers, but not floats or complexes
404- if dtype is None :
405- if x .dtype == xp .float32 :
406- dtype = xp .float64
407- elif x .dtype == xp .complex64 :
408- dtype = xp .complex128
409- return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims , ** kwargs )
410-
411- def prod (
412- x : ndarray ,
413- / ,
414- xp ,
415- * ,
416- axis : Optional [Union [int , Tuple [int , ...]]] = None ,
417- dtype : Optional [Dtype ] = None ,
418- keepdims : bool = False ,
419- ** kwargs ,
420- ) -> ndarray :
421- if dtype is None :
422- if x .dtype == xp .float32 :
423- dtype = xp .float64
424- elif x .dtype == xp .complex64 :
425- dtype = xp .complex128
426- return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims , ** kwargs )
427-
428431# ceil, floor, and trunc return integers for integer inputs
429432
430433def ceil (x : ndarray , / , xp , ** kwargs ) -> ndarray :
@@ -521,10 +524,17 @@ def isdtype(
521524 # array_api_strict implementation will be very strict.
522525 return dtype == kind
523526
527+ # unstack is a new function in the 2023.12 array API standard
def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
    """Split ``x`` into a tuple of sub-arrays taken along ``axis``.

    ``axis`` is moved to the front with ``xp.moveaxis`` and the result is
    then iterated over its leading dimension, giving one element per index
    along ``axis``.

    Raises
    ------
    ValueError
        If ``x`` is zero-dimensional.
    """
    if x.ndim == 0:
        raise ValueError("Input array must be at least 1-d.")
    axis_first = xp.moveaxis(x, axis, 0)
    return tuple(axis_first)
532+
# Public names exported by this module, one per line.
__all__ = [
    # array creation
    'arange',
    'empty',
    'empty_like',
    'eye',
    'full',
    'full_like',
    'linspace',
    'ones',
    'ones_like',
    'zeros',
    'zeros_like',
    # unique_* result types and functions
    'UniqueAllResult',
    'UniqueCountsResult',
    'UniqueInverseResult',
    'unique_all',
    'unique_counts',
    'unique_inverse',
    'unique_values',
    # remaining wrappers
    'astype',
    'std',
    'var',
    'cumulative_sum',
    'clip',
    'permute_dims',
    'reshape',
    'argsort',
    'sort',
    'nonzero',
    'ceil',
    'floor',
    'trunc',
    'matmul',
    'matrix_transpose',
    'tensordot',
    'vecdot',
    'isdtype',
    'unstack',
]
0 commit comments