@@ -120,6 +120,25 @@ def count_nonzero(
120120 return cp .expand_dims (result , axis )
121121 return result
122122
123+ # ceil, floor, and trunc return integers for integer inputs
124+
125+ def ceil (x : Array , / ) -> Array :
126+ if cp .issubdtype (x .dtype , cp .integer ):
127+ return x .copy ()
128+ return cp .ceil (x )
129+
130+
131+ def floor (x : Array , / ) -> Array :
132+ if cp .issubdtype (x .dtype , cp .integer ):
133+ return x .copy ()
134+ return cp .floor (x )
135+
136+
137+ def trunc (x : Array , / ) -> Array :
138+ if cp .issubdtype (x .dtype , cp .integer ):
139+ return x .copy ()
140+ return cp .trunc (x )
141+
123142
124143# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
125144def take_along_axis (x : Array , indices : Array , / , * , axis : int = - 1 ):
@@ -148,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
148167 'atan2' , 'atanh' , 'bitwise_left_shift' ,
149168 'bitwise_invert' , 'bitwise_right_shift' ,
150169 'bool' , 'concat' , 'count_nonzero' , 'pow' , 'sign' ,
151- 'take_along_axis' ]
170+ 'ceil' , 'floor' , 'trunc' , ' take_along_axis' ]
152171
153172_all_ignore = ['cp' , 'get_xp' ]
0 commit comments