@@ -528,21 +528,27 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
528528
529529
530530def ceil (x : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
531- if xp .issubdtype (x .dtype , xp .integer ):
532- return x
533- return xp .ceil (x , ** kwargs )
531+ result = xp .ceil (x , ** kwargs )
532+ if result .dtype != x .dtype :
533+ # numpy < 2: ceil(int array) is float
534+ result = xp .asarray (result , dtype = x .dtype )
535+ return result
534536
535537
536538def floor (x : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
537- if xp .issubdtype (x .dtype , xp .integer ):
538- return x
539- return xp .floor (x , ** kwargs )
539+ result = xp .floor (x , ** kwargs )
540+ if result .dtype != x .dtype :
541+ # numpy < 2: floor(int array) is float
542+ result = xp .asarray (result , dtype = x .dtype )
543+ return result
540544
541545
542546def trunc (x : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
543- if xp .issubdtype (x .dtype , xp .integer ):
544- return x
545- return xp .trunc (x , ** kwargs )
547+ result = xp .trunc (x , ** kwargs )
548+ if result .dtype != x .dtype :
549+ # numpy < 2: trunc(int array) is float
550+ result = xp .asarray (result , dtype = x .dtype )
551+ return result
546552
547553
548554# linear algebra functions
0 commit comments