Skip to content

Commit 8f68787

Browse files
committed
add back usm_type tests in simplified framework
1 parent 16361be commit 8f68787

File tree

1 file changed

+314
-0
lines changed

1 file changed

+314
-0
lines changed
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl
20+
import dpctl.tensor as dpt
21+
from dpctl.tests.helper import get_queue_or_skip
22+
23+
from .utils import _usm_types
24+
25+
26+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
27+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
28+
def test_binary_usm_type_coercion(op1_usm_type, op2_usm_type):
29+
get_queue_or_skip()
30+
31+
sz = 128
32+
dt = dpt.int32
33+
ar1 = dpt.ones(sz, dtype=dt, usm_type=op1_usm_type)
34+
ar2 = dpt.ones_like(ar1, dtype=dt, usm_type=op2_usm_type)
35+
36+
r = dpt.add(ar1, ar2)
37+
assert isinstance(r, dpt.usm_ndarray)
38+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
39+
(op1_usm_type, op2_usm_type)
40+
)
41+
assert r.usm_type == expected_usm_type
42+
43+
44+
@pytest.mark.parametrize("usm_type", _usm_types)
45+
class TestUnaryUSMType:
46+
def unary_elementwise(self, fn, usm_type, dtype="f4"):
47+
q = get_queue_or_skip()
48+
x = dpt.asarray(
49+
[1, 2, 3, 4], dtype=dtype, usm_type=usm_type, sycl_queue=q
50+
)
51+
return getattr(dpt, fn)(x)
52+
53+
def test_abs(self, usm_type):
54+
self.unary_elementwise("abs", usm_type)
55+
56+
def test_acos(self, usm_type):
57+
self.unary_elementwise("acos", usm_type)
58+
59+
def test_acosh(self, usm_type):
60+
self.unary_elementwise("acosh", usm_type)
61+
62+
def test_angle(self, usm_type):
63+
self.unary_elementwise("angle", usm_type, dtype="c8")
64+
65+
def test_asin(self, usm_type):
66+
self.unary_elementwise("asin", usm_type)
67+
68+
def test_asinh(self, usm_type):
69+
self.unary_elementwise("asinh", usm_type)
70+
71+
def test_atan(self, usm_type):
72+
self.unary_elementwise("atan", usm_type)
73+
74+
def test_atanh(self, usm_type):
75+
self.unary_elementwise("atanh", usm_type)
76+
77+
def test_bitwise_invert(self, usm_type):
78+
self.unary_elementwise("bitwise_invert", usm_type, dtype="i4")
79+
80+
def test_cbrt(self, usm_type):
81+
self.unary_elementwise("cbrt", usm_type)
82+
83+
def test_ceil(self, usm_type):
84+
self.unary_elementwise("ceil", usm_type)
85+
86+
def test_conj(self, usm_type):
87+
self.unary_elementwise("conj", usm_type)
88+
89+
def test_cos(self, usm_type):
90+
self.unary_elementwise("cos", usm_type)
91+
92+
def test_cosh(self, usm_type):
93+
self.unary_elementwise("cosh", usm_type)
94+
95+
def test_exp(self, usm_type):
96+
self.unary_elementwise("exp", usm_type)
97+
98+
def test_exp2(self, usm_type):
99+
self.unary_elementwise("exp2", usm_type)
100+
101+
def test_expm1(self, usm_type):
102+
self.unary_elementwise("expm1", usm_type)
103+
104+
def test_floor(self, usm_type):
105+
self.unary_elementwise("floor", usm_type)
106+
107+
def test_imag(self, usm_type):
108+
self.unary_elementwise("imag", usm_type)
109+
110+
def test_isfinite(self, usm_type):
111+
self.unary_elementwise("isfinite", usm_type)
112+
113+
def test_isinf(self, usm_type):
114+
self.unary_elementwise("isinf", usm_type)
115+
116+
def test_isnan(self, usm_type):
117+
self.unary_elementwise("isnan", usm_type)
118+
119+
def test_log(self, usm_type):
120+
self.unary_elementwise("log", usm_type)
121+
122+
def test_log1p(self, usm_type):
123+
self.unary_elementwise("log1p", usm_type)
124+
125+
def test_log2(self, usm_type):
126+
self.unary_elementwise("log2", usm_type)
127+
128+
def test_log10(self, usm_type):
129+
self.unary_elementwise("log10", usm_type)
130+
131+
def test_logical_not(self, usm_type):
132+
self.unary_elementwise("logical_not", usm_type, dtype="i4")
133+
134+
def test_negative(self, usm_type):
135+
self.unary_elementwise("negative", usm_type)
136+
137+
def test_positive(self, usm_type):
138+
self.unary_elementwise("positive", usm_type)
139+
140+
def test_proj(self, usm_type):
141+
self.unary_elementwise("proj", usm_type, dtype="c8")
142+
143+
def test_real(self, usm_type):
144+
self.unary_elementwise("real", usm_type, dtype="c8")
145+
146+
def test_reciprocal(self, usm_type):
147+
self.unary_elementwise("reciprocal", usm_type)
148+
149+
def test_round(self, usm_type):
150+
self.unary_elementwise("round", usm_type)
151+
152+
def test_rsqrt(self, usm_type):
153+
self.unary_elementwise("rsqrt", usm_type)
154+
155+
def test_sign(self, usm_type):
156+
self.unary_elementwise("sign", usm_type)
157+
158+
def test_signbit(self, usm_type):
159+
self.unary_elementwise("signbit", usm_type)
160+
161+
def test_sin(self, usm_type):
162+
self.unary_elementwise("sin", usm_type)
163+
164+
def test_sinh(self, usm_type):
165+
self.unary_elementwise("sinh", usm_type)
166+
167+
def test_square(self, usm_type):
168+
self.unary_elementwise("square", usm_type)
169+
170+
def test_sqrt(self, usm_type):
171+
self.unary_elementwise("sqrt", usm_type)
172+
173+
def test_tan(self, usm_type):
174+
self.unary_elementwise("tan", usm_type)
175+
176+
def test_tanh(self, usm_type):
177+
self.unary_elementwise("tanh", usm_type)
178+
179+
def test_trunc(self, usm_type):
180+
self.unary_elementwise("trunc", usm_type)
181+
182+
def test_usm_basic(self, usm_type):
183+
q = get_queue_or_skip()
184+
185+
sz = 128
186+
dt = dpt.int32
187+
x = dpt.ones(sz, dtype=dt, usm_type=usm_type, sycl_queue=q)
188+
189+
r = dpt.abs(x)
190+
assert isinstance(r, dpt.usm_ndarray)
191+
assert r.usm_type == x.usm_type
192+
193+
194+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
195+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
196+
class TestBinaryUSMType:
197+
def binary_elementwise(self, fn, op1_usm_type, op2_usm_type, dtype="f4"):
198+
q = get_queue_or_skip()
199+
x = dpt.asarray(
200+
[1, 2, 3, 4, 5, 6], dtype=dtype, usm_type=op1_usm_type, sycl_queue=q
201+
)
202+
y = dpt.asarray(
203+
[1, 2, 3, 4, 5, 6], dtype=dtype, usm_type=op2_usm_type, sycl_queue=q
204+
)
205+
return getattr(dpt, fn)(x, y)
206+
207+
def test_add(self, op1_usm_type, op2_usm_type):
208+
self.binary_elementwise("add", op1_usm_type, op2_usm_type)
209+
210+
def test_atan2(self, op1_usm_type, op2_usm_type):
211+
self.binary_elementwise("atan2", op1_usm_type, op2_usm_type)
212+
213+
def test_bitwise_and(self, op1_usm_type, op2_usm_type):
214+
self.binary_elementwise(
215+
"bitwise_and", op1_usm_type, op2_usm_type, dtype="i4"
216+
)
217+
218+
def test_bitwise_left_shift(self, op1_usm_type, op2_usm_type):
219+
self.binary_elementwise(
220+
"bitwise_left_shift", op1_usm_type, op2_usm_type, dtype="i4"
221+
)
222+
223+
def test_bitwise_or(self, op1_usm_type, op2_usm_type):
224+
self.binary_elementwise(
225+
"bitwise_or", op1_usm_type, op2_usm_type, dtype="i4"
226+
)
227+
228+
def test_bitwise_right_shift(self, op1_usm_type, op2_usm_type):
229+
self.binary_elementwise(
230+
"bitwise_right_shift", op1_usm_type, op2_usm_type, dtype="i4"
231+
)
232+
233+
def test_bitwise_xor(self, op1_usm_type, op2_usm_type):
234+
self.binary_elementwise(
235+
"bitwise_xor", op1_usm_type, op2_usm_type, dtype="i4"
236+
)
237+
238+
def test_copysign(self, op1_usm_type, op2_usm_type):
239+
self.binary_elementwise("copysign", op1_usm_type, op2_usm_type)
240+
241+
def test_divide(self, op1_usm_type, op2_usm_type):
242+
self.binary_elementwise("divide", op1_usm_type, op2_usm_type)
243+
244+
def test_equal(self, op1_usm_type, op2_usm_type):
245+
self.binary_elementwise("equal", op1_usm_type, op2_usm_type)
246+
247+
def test_floor_divide(self, op1_usm_type, op2_usm_type):
248+
self.binary_elementwise("floor_divide", op1_usm_type, op2_usm_type)
249+
250+
def test_hypot(self, op1_usm_type, op2_usm_type):
251+
self.binary_elementwise("hypot", op1_usm_type, op2_usm_type)
252+
253+
def test_greater(self, op1_usm_type, op2_usm_type):
254+
self.binary_elementwise("greater", op1_usm_type, op2_usm_type)
255+
256+
def test_greater_equal(self, op1_usm_type, op2_usm_type):
257+
self.binary_elementwise("greater_equal", op1_usm_type, op2_usm_type)
258+
259+
def test_less(self, op1_usm_type, op2_usm_type):
260+
self.binary_elementwise("less", op1_usm_type, op2_usm_type)
261+
262+
def test_less_equal(self, op1_usm_type, op2_usm_type):
263+
self.binary_elementwise("less_equal", op1_usm_type, op2_usm_type)
264+
265+
def test_logaddexp(self, op1_usm_type, op2_usm_type):
266+
self.binary_elementwise("logaddexp", op1_usm_type, op2_usm_type)
267+
268+
def test_logical_and(self, op1_usm_type, op2_usm_type):
269+
self.binary_elementwise("logical_and", op1_usm_type, op2_usm_type)
270+
271+
def test_logical_or(self, op1_usm_type, op2_usm_type):
272+
self.binary_elementwise("logical_or", op1_usm_type, op2_usm_type)
273+
274+
def test_logical_xor(self, op1_usm_type, op2_usm_type):
275+
self.binary_elementwise("logical_xor", op1_usm_type, op2_usm_type)
276+
277+
def test_maximum(self, op1_usm_type, op2_usm_type):
278+
self.binary_elementwise("maximum", op1_usm_type, op2_usm_type)
279+
280+
def test_minimum(self, op1_usm_type, op2_usm_type):
281+
self.binary_elementwise("minimum", op1_usm_type, op2_usm_type)
282+
283+
def test_multiply(self, op1_usm_type, op2_usm_type):
284+
self.binary_elementwise("multiply", op1_usm_type, op2_usm_type)
285+
286+
def test_nextafter(self, op1_usm_type, op2_usm_type):
287+
self.binary_elementwise("nextafter", op1_usm_type, op2_usm_type)
288+
289+
def test_not_equal(self, op1_usm_type, op2_usm_type):
290+
self.binary_elementwise("not_equal", op1_usm_type, op2_usm_type)
291+
292+
def test_pow(self, op1_usm_type, op2_usm_type):
293+
self.binary_elementwise("pow", op1_usm_type, op2_usm_type)
294+
295+
def test_remainder(self, op1_usm_type, op2_usm_type):
296+
self.binary_elementwise("remainder", op1_usm_type, op2_usm_type)
297+
298+
def test_subtract(self, op1_usm_type, op2_usm_type):
299+
self.binary_elementwise("subtract", op1_usm_type, op2_usm_type)
300+
301+
def test_binary_usm_type_coercion(self, op1_usm_type, op2_usm_type):
302+
q = get_queue_or_skip()
303+
304+
sz = 128
305+
dt = dpt.int32
306+
ar1 = dpt.ones(sz, dtype=dt, usm_type=op1_usm_type, sycl_queue=q)
307+
ar2 = dpt.ones_like(ar1, dtype=dt, usm_type=op2_usm_type, sycl_queue=q)
308+
309+
r = dpt.add(ar1, ar2)
310+
assert isinstance(r, dpt.usm_ndarray)
311+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
312+
(op1_usm_type, op2_usm_type)
313+
)
314+
assert r.usm_type == expected_usm_type

0 commit comments

Comments
 (0)