Skip to content

Commit 86965dd

Browse files
committed
add back usm_type tests in simplified framework
1 parent 16361be commit 86965dd

File tree

3 files changed

+298
-38
lines changed

3 files changed

+298
-38
lines changed

dpctl/tests/elementwise/test_abs.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import dpctl.tensor as dpt
2424
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2525

26-
from .utils import _all_dtypes, _complex_fp_dtypes, _real_fp_dtypes, _usm_types
26+
from .utils import _all_dtypes, _complex_fp_dtypes, _real_fp_dtypes
2727

2828

2929
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -51,25 +51,6 @@ def test_abs_out_type(dtype):
5151
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X)))
5252

5353

54-
@pytest.mark.parametrize("usm_type", _usm_types)
55-
def test_abs_usm_type(usm_type):
56-
q = get_queue_or_skip()
57-
58-
arg_dt = np.dtype("i4")
59-
input_shape = (10, 10, 10, 10)
60-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
61-
X[..., 0::2] = 1
62-
X[..., 1::2] = 0
63-
64-
Y = dpt.abs(X)
65-
assert Y.usm_type == X.usm_type
66-
assert Y.sycl_queue == X.sycl_queue
67-
assert Y.flags.c_contiguous
68-
69-
expected_Y = dpt.asnumpy(X)
70-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
71-
72-
7354
def test_abs_types_property():
7455
get_queue_or_skip()
7556
types = dpt.abs.types

dpctl/tests/elementwise/test_add.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2727
from dpctl.utils import ExecutionPlacementError
2828

29-
from .utils import _all_dtypes, _compare_dtypes, _usm_types
29+
from .utils import _all_dtypes, _compare_dtypes
3030

3131

3232
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
@@ -71,23 +71,6 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
7171
assert (dpt.asnumpy(r2) == np.full(r2.shape, 2, dtype=r2.dtype)).all()
7272

7373

74-
@pytest.mark.parametrize("op1_usm_type", _usm_types)
75-
@pytest.mark.parametrize("op2_usm_type", _usm_types)
76-
def test_add_usm_type_matrix(op1_usm_type, op2_usm_type):
77-
get_queue_or_skip()
78-
79-
sz = 128
80-
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
81-
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
82-
83-
r = dpt.add(ar1, ar2)
84-
assert isinstance(r, dpt.usm_ndarray)
85-
expected_usm_type = dpctl.utils.get_coerced_usm_type(
86-
(op1_usm_type, op2_usm_type)
87-
)
88-
assert r.usm_type == expected_usm_type
89-
90-
9174
def test_add_order():
9275
get_queue_or_skip()
9376

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl
20+
import dpctl.tensor as dpt
21+
from dpctl.tests.helper import get_queue_or_skip
22+
23+
from .utils import _usm_types
24+
25+
26+
@pytest.mark.parametrize("usm_type", _usm_types)
27+
class TestUnaryUSMType:
28+
def unary_elementwise(self, fn, usm_type, dtype="f4"):
29+
q = get_queue_or_skip()
30+
x = dpt.asarray(
31+
[1, 2, 3, 4], dtype=dtype, usm_type=usm_type, sycl_queue=q
32+
)
33+
return getattr(dpt, fn)(x)
34+
35+
def test_abs(self, usm_type):
36+
self.unary_elementwise("abs", usm_type)
37+
38+
def test_acos(self, usm_type):
39+
self.unary_elementwise("acos", usm_type)
40+
41+
def test_acosh(self, usm_type):
42+
self.unary_elementwise("acosh", usm_type)
43+
44+
def test_angle(self, usm_type):
45+
self.unary_elementwise("angle", usm_type, dtype="c8")
46+
47+
def test_asin(self, usm_type):
48+
self.unary_elementwise("asin", usm_type)
49+
50+
def test_asinh(self, usm_type):
51+
self.unary_elementwise("asinh", usm_type)
52+
53+
def test_atan(self, usm_type):
54+
self.unary_elementwise("atan", usm_type)
55+
56+
def test_atanh(self, usm_type):
57+
self.unary_elementwise("atanh", usm_type)
58+
59+
def test_bitwise_invert(self, usm_type):
60+
self.unary_elementwise("bitwise_invert", usm_type, dtype="i4")
61+
62+
def test_cbrt(self, usm_type):
63+
self.unary_elementwise("cbrt", usm_type)
64+
65+
def test_ceil(self, usm_type):
66+
self.unary_elementwise("ceil", usm_type)
67+
68+
def test_conj(self, usm_type):
69+
self.unary_elementwise("conj", usm_type)
70+
71+
def test_cos(self, usm_type):
72+
self.unary_elementwise("cos", usm_type)
73+
74+
def test_cosh(self, usm_type):
75+
self.unary_elementwise("cosh", usm_type)
76+
77+
def test_exp(self, usm_type):
78+
self.unary_elementwise("exp", usm_type)
79+
80+
def test_exp2(self, usm_type):
81+
self.unary_elementwise("exp2", usm_type)
82+
83+
def test_expm1(self, usm_type):
84+
self.unary_elementwise("expm1", usm_type)
85+
86+
def test_floor(self, usm_type):
87+
self.unary_elementwise("floor", usm_type)
88+
89+
def test_imag(self, usm_type):
90+
self.unary_elementwise("imag", usm_type)
91+
92+
def test_isfinite(self, usm_type):
93+
self.unary_elementwise("isfinite", usm_type)
94+
95+
def test_isinf(self, usm_type):
96+
self.unary_elementwise("isinf", usm_type)
97+
98+
def test_isnan(self, usm_type):
99+
self.unary_elementwise("isnan", usm_type)
100+
101+
def test_log(self, usm_type):
102+
self.unary_elementwise("log", usm_type)
103+
104+
def test_log1p(self, usm_type):
105+
self.unary_elementwise("log1p", usm_type)
106+
107+
def test_log2(self, usm_type):
108+
self.unary_elementwise("log2", usm_type)
109+
110+
def test_log10(self, usm_type):
111+
self.unary_elementwise("log10", usm_type)
112+
113+
def test_logical_not(self, usm_type):
114+
self.unary_elementwise("logical_not", usm_type, dtype="i4")
115+
116+
def test_negative(self, usm_type):
117+
self.unary_elementwise("negative", usm_type)
118+
119+
def test_positive(self, usm_type):
120+
self.unary_elementwise("positive", usm_type)
121+
122+
def test_proj(self, usm_type):
123+
self.unary_elementwise("proj", usm_type, dtype="c8")
124+
125+
def test_real(self, usm_type):
126+
self.unary_elementwise("real", usm_type, dtype="c8")
127+
128+
def test_reciprocal(self, usm_type):
129+
self.unary_elementwise("reciprocal", usm_type)
130+
131+
def test_round(self, usm_type):
132+
self.unary_elementwise("round", usm_type)
133+
134+
def test_rsqrt(self, usm_type):
135+
self.unary_elementwise("rsqrt", usm_type)
136+
137+
def test_sign(self, usm_type):
138+
self.unary_elementwise("sign", usm_type)
139+
140+
def test_signbit(self, usm_type):
141+
self.unary_elementwise("signbit", usm_type)
142+
143+
def test_sin(self, usm_type):
144+
self.unary_elementwise("sin", usm_type)
145+
146+
def test_sinh(self, usm_type):
147+
self.unary_elementwise("sinh", usm_type)
148+
149+
def test_square(self, usm_type):
150+
self.unary_elementwise("square", usm_type)
151+
152+
def test_sqrt(self, usm_type):
153+
self.unary_elementwise("sqrt", usm_type)
154+
155+
def test_tan(self, usm_type):
156+
self.unary_elementwise("tan", usm_type)
157+
158+
def test_tanh(self, usm_type):
159+
self.unary_elementwise("tanh", usm_type)
160+
161+
def test_trunc(self, usm_type):
162+
self.unary_elementwise("trunc", usm_type)
163+
164+
def test_usm_basic(self, usm_type):
165+
q = get_queue_or_skip()
166+
167+
sz = 128
168+
dt = dpt.int32
169+
x = dpt.ones(sz, dtype=dt, usm_type=usm_type, sycl_queue=q)
170+
171+
r = dpt.abs(x)
172+
assert isinstance(r, dpt.usm_ndarray)
173+
assert r.usm_type == x.usm_type
174+
175+
176+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
177+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
178+
class TestBinaryUSMType:
179+
def binary_elementwise(self, fn, op1_usm_type, op2_usm_type, dtype="f4"):
180+
q = get_queue_or_skip()
181+
x = dpt.asarray(
182+
[1, 2, 3, 4, 5, 6], dtype=dtype, usm_type=op1_usm_type, sycl_queue=q
183+
)
184+
y = dpt.asarray(
185+
[1, 2, 3, 4, 5, 6], dtype=dtype, usm_type=op2_usm_type, sycl_queue=q
186+
)
187+
return getattr(dpt, fn)(x, y)
188+
189+
def test_add(self, op1_usm_type, op2_usm_type):
190+
self.binary_elementwise("add", op1_usm_type, op2_usm_type)
191+
192+
def test_atan2(self, op1_usm_type, op2_usm_type):
193+
self.binary_elementwise("atan2", op1_usm_type, op2_usm_type)
194+
195+
def test_bitwise_and(self, op1_usm_type, op2_usm_type):
196+
self.binary_elementwise(
197+
"bitwise_and", op1_usm_type, op2_usm_type, dtype="i4"
198+
)
199+
200+
def test_bitwise_left_shift(self, op1_usm_type, op2_usm_type):
201+
self.binary_elementwise(
202+
"bitwise_left_shift", op1_usm_type, op2_usm_type, dtype="i4"
203+
)
204+
205+
def test_bitwise_or(self, op1_usm_type, op2_usm_type):
206+
self.binary_elementwise(
207+
"bitwise_or", op1_usm_type, op2_usm_type, dtype="i4"
208+
)
209+
210+
def test_bitwise_right_shift(self, op1_usm_type, op2_usm_type):
211+
self.binary_elementwise(
212+
"bitwise_right_shift", op1_usm_type, op2_usm_type, dtype="i4"
213+
)
214+
215+
def test_bitwise_xor(self, op1_usm_type, op2_usm_type):
216+
self.binary_elementwise(
217+
"bitwise_xor", op1_usm_type, op2_usm_type, dtype="i4"
218+
)
219+
220+
def test_copysign(self, op1_usm_type, op2_usm_type):
221+
self.binary_elementwise("copysign", op1_usm_type, op2_usm_type)
222+
223+
def test_divide(self, op1_usm_type, op2_usm_type):
224+
self.binary_elementwise("divide", op1_usm_type, op2_usm_type)
225+
226+
def test_equal(self, op1_usm_type, op2_usm_type):
227+
self.binary_elementwise("equal", op1_usm_type, op2_usm_type)
228+
229+
def test_floor_divide(self, op1_usm_type, op2_usm_type):
230+
self.binary_elementwise("floor_divide", op1_usm_type, op2_usm_type)
231+
232+
def test_hypot(self, op1_usm_type, op2_usm_type):
233+
self.binary_elementwise("hypot", op1_usm_type, op2_usm_type)
234+
235+
def test_greater(self, op1_usm_type, op2_usm_type):
236+
self.binary_elementwise("greater", op1_usm_type, op2_usm_type)
237+
238+
def test_greater_equal(self, op1_usm_type, op2_usm_type):
239+
self.binary_elementwise("greater_equal", op1_usm_type, op2_usm_type)
240+
241+
def test_less(self, op1_usm_type, op2_usm_type):
242+
self.binary_elementwise("less", op1_usm_type, op2_usm_type)
243+
244+
def test_less_equal(self, op1_usm_type, op2_usm_type):
245+
self.binary_elementwise("less_equal", op1_usm_type, op2_usm_type)
246+
247+
def test_logaddexp(self, op1_usm_type, op2_usm_type):
248+
self.binary_elementwise("logaddexp", op1_usm_type, op2_usm_type)
249+
250+
def test_logical_and(self, op1_usm_type, op2_usm_type):
251+
self.binary_elementwise("logical_and", op1_usm_type, op2_usm_type)
252+
253+
def test_logical_or(self, op1_usm_type, op2_usm_type):
254+
self.binary_elementwise("logical_or", op1_usm_type, op2_usm_type)
255+
256+
def test_logical_xor(self, op1_usm_type, op2_usm_type):
257+
self.binary_elementwise("logical_xor", op1_usm_type, op2_usm_type)
258+
259+
def test_maximum(self, op1_usm_type, op2_usm_type):
260+
self.binary_elementwise("maximum", op1_usm_type, op2_usm_type)
261+
262+
def test_minimum(self, op1_usm_type, op2_usm_type):
263+
self.binary_elementwise("minimum", op1_usm_type, op2_usm_type)
264+
265+
def test_multiply(self, op1_usm_type, op2_usm_type):
266+
self.binary_elementwise("multiply", op1_usm_type, op2_usm_type)
267+
268+
def test_nextafter(self, op1_usm_type, op2_usm_type):
269+
self.binary_elementwise("nextafter", op1_usm_type, op2_usm_type)
270+
271+
def test_not_equal(self, op1_usm_type, op2_usm_type):
272+
self.binary_elementwise("not_equal", op1_usm_type, op2_usm_type)
273+
274+
def test_pow(self, op1_usm_type, op2_usm_type):
275+
self.binary_elementwise("pow", op1_usm_type, op2_usm_type)
276+
277+
def test_remainder(self, op1_usm_type, op2_usm_type):
278+
self.binary_elementwise("remainder", op1_usm_type, op2_usm_type)
279+
280+
def test_subtract(self, op1_usm_type, op2_usm_type):
281+
self.binary_elementwise("subtract", op1_usm_type, op2_usm_type)
282+
283+
def test_binary_usm_type_coercion(self, op1_usm_type, op2_usm_type):
284+
q = get_queue_or_skip()
285+
286+
sz = 128
287+
dt = dpt.int32
288+
ar1 = dpt.ones(sz, dtype=dt, usm_type=op1_usm_type, sycl_queue=q)
289+
ar2 = dpt.ones_like(ar1, dtype=dt, usm_type=op2_usm_type, sycl_queue=q)
290+
291+
r = dpt.add(ar1, ar2)
292+
assert isinstance(r, dpt.usm_ndarray)
293+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
294+
(op1_usm_type, op2_usm_type)
295+
)
296+
assert r.usm_type == expected_usm_type

0 commit comments

Comments
 (0)