@@ -420,7 +420,9 @@ def _normalize_axes(axis, ndim):
420420 for a in axis :
421421 if a < lower or a > upper :
422422 # Match paddle error message (e.g., from sum())
423- raise IndexError (f"Dimension out of range (expected to be in range of [{ lower } , { upper } ], but got { a } " )
423+ raise IndexError (
424+ f"Dimension out of range (expected to be in range of [{ lower } , { upper } ], but got { a } "
425+ )
424426 if a < 0 :
425427 a = a + ndim
426428 if a in axes :
@@ -480,7 +482,9 @@ def prod(
480482
481483 # paddle.prod doesn't support multiple axes
482484 if isinstance (axis , tuple ):
483- return _reduce_multiple_axes (paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs )
485+ return _reduce_multiple_axes (
486+ paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs
487+ )
484488 if axis is None :
485489 # paddle doesn't support keepdims with axis=None
486490 res = paddle .prod (x , dtype = dtype , ** kwargs )
@@ -610,7 +614,9 @@ def std(
610614 if isinstance (correction , float ):
611615 _correction = int (correction )
612616 if correction != _correction :
613- raise NotImplementedError ("float correction in paddle std() is not yet supported" )
617+ raise NotImplementedError (
618+ "float correction in paddle std() is not yet supported"
619+ )
614620 elif isinstance (correction , int ):
615621 if correction not in [0 , 1 ]:
616622 raise NotImplementedError ("correction only can be 0 or 1" )
@@ -648,7 +654,9 @@ def var(
648654 if isinstance (correction , float ):
649655 _correction = int (correction )
650656 if correction != _correction :
651- raise NotImplementedError ("float correction in paddle std() is not yet supported" )
657+ raise NotImplementedError (
658+ "float correction in paddle std() is not yet supported"
659+ )
652660 elif isinstance (correction , int ):
653661 if correction not in [0 , 1 ]:
654662 raise NotImplementedError ("correction only can be 0 or 1" )
@@ -709,7 +717,9 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
709717
710718# The axis parameter doesn't work for flip() and roll()
711719# accept axis=None
712- def flip (x : array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None , ** kwargs ) -> array :
720+ def flip (
721+ x : array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None , ** kwargs
722+ ) -> array :
713723 if axis is None :
714724 axis = tuple (range (x .ndim ))
715725 # paddle.flip doesn't accept dim as an int but the method does
@@ -738,21 +748,27 @@ def where(condition: array, x1: array, x2: array, /) -> array:
738748 return paddle .where (condition , x1 , x2 )
739749
740750
741- def empty_like (x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ) -> array :
751+ def empty_like (
752+ x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
753+ ) -> array :
742754 out = paddle .empty_like (x , dtype = dtype )
743755 if device is not None :
744756 out = out .to (device )
745757 return out
746758
747759
748- def zeros_like (x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ) -> array :
760+ def zeros_like (
761+ x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
762+ ) -> array :
749763 out = paddle .zeros_like (x , dtype = dtype )
750764 if device is not None :
751765 out = out .to (device )
752766 return out
753767
754768
755- def ones_like (x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ) -> array :
769+ def ones_like (
770+ x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
771+ ) -> array :
756772 out = paddle .ones_like (x , dtype = dtype )
757773 if device is not None :
758774 out = out .to (device )
@@ -774,7 +790,9 @@ def full_like(
774790
775791
776792# paddle.reshape doesn't have the copy keyword
777- def reshape (x : array , / , shape : Tuple [int , ...], copy : Optional [bool ] = None , ** kwargs ) -> array :
793+ def reshape (
794+ x : array , / , shape : Tuple [int , ...], copy : Optional [bool ] = None , ** kwargs
795+ ) -> array :
778796 return paddle .reshape (x , shape , ** kwargs )
779797
780798
@@ -825,7 +843,9 @@ def linspace(
825843 ** kwargs ,
826844) -> array :
827845 if not endpoint :
828- return paddle .linspace (start , stop , num + 1 , dtype = dtype , ** kwargs ).to (device )[:- 1 ]
846+ return paddle .linspace (start , stop , num + 1 , dtype = dtype , ** kwargs ).to (device )[
847+ :- 1
848+ ]
829849 return paddle .linspace (start , stop , num , dtype = dtype , ** kwargs ).to (device )
830850
831851
@@ -890,7 +910,9 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
890910 return paddle .unsqueeze (x , axis )
891911
892912
893- def astype (x : array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None ) -> array :
913+ def astype (
914+ x : array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None
915+ ) -> array :
894916 # if copy is not None:
895917 # raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
896918 t = x .to (dtype , device = device )
@@ -1036,7 +1058,7 @@ def sign(x: array, /) -> array:
10361058 else :
10371059 out = paddle .sign (x )
10381060 if paddle .is_floating_point (x ):
1039- out = paddle .where (paddle .isnan (x ), paddle .nan , out )
1061+ out = paddle .where (paddle .isnan (x ), paddle .full ( x . shape , paddle . nan ) , out )
10401062 return out
10411063
10421064
@@ -1083,7 +1105,8 @@ def asarray(
10831105 return obj
10841106 else :
10851107 raise NotImplementedError (
1086- "asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
1108+ "asarray(obj, ..., copy=False) is not supported "
1109+ "for obj do not has '__dlpack__()' method"
10871110 )
10881111 elif copy is True :
10891112 obj = np .array (obj , copy = True )
@@ -1164,11 +1187,18 @@ def _isscalar(a):
11641187
11651188
11661189def cumulative_sum (
1167- x : array , / , * , axis : Optional [int ] = None , dtype : Optional [Dtype ] = None , include_initial : bool = False
1190+ x : array ,
1191+ / ,
1192+ * ,
1193+ axis : Optional [int ] = None ,
1194+ dtype : Optional [Dtype ] = None ,
1195+ include_initial : bool = False ,
11681196) -> array :
11691197 if axis is None :
11701198 if x .ndim > 1 :
1171- raise ValueError ("axis must be specified in cumulative_sum for more than one dimension" )
1199+ raise ValueError (
1200+ "axis must be specified in cumulative_sum for more than one dimension"
1201+ )
11721202 axis = 0
11731203
11741204 res = paddle .cumsum (x , axis = axis , dtype = dtype )
@@ -1185,7 +1215,12 @@ def cumulative_sum(
11851215
11861216
11871217def searchsorted (
1188- x1 : array , x2 : array , / , * , side : Literal ["left" , "right" ] = "left" , sorter : array | None = None
1218+ x1 : array ,
1219+ x2 : array ,
1220+ / ,
1221+ * ,
1222+ side : Literal ["left" , "right" ] = "left" ,
1223+ sorter : array | None = None ,
11891224) -> array :
11901225 if sorter is None :
11911226 return paddle .searchsorted (x1 , x2 , right = (side == "right" ))
0 commit comments