14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
- import itertools
18
- import os
19
- import re
20
-
21
17
import numpy as np
22
18
import pytest
23
19
from numpy .testing import assert_allclose
34
30
(np .arctanh , dpt .atanh ),
35
31
]
36
32
_all_funcs = _hyper_funcs + _inv_hyper_funcs
37
- _dpt_funcs = [t [1 ] for t in _all_funcs ]
38
33
39
34
40
35
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -45,17 +40,10 @@ def test_hyper_out_type(np_call, dpt_call, dtype):
45
40
46
41
a = 1 if np_call == np .arccosh else 0
47
42
48
- X = dpt .asarray (a , dtype = dtype , sycl_queue = q )
49
- expected_dtype = np_call (np .array (a , dtype = dtype )).dtype
50
- expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
51
- assert dpt_call (X ).dtype == expected_dtype
52
-
53
- X = dpt .asarray (a , dtype = dtype , sycl_queue = q )
43
+ x = dpt .asarray (a , dtype = dtype , sycl_queue = q )
54
44
expected_dtype = np_call (np .array (a , dtype = dtype )).dtype
55
45
expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
56
- Y = dpt .empty_like (X , dtype = expected_dtype )
57
- dpt_call (X , out = Y )
58
- assert_allclose (dpt .asnumpy (dpt_call (X )), dpt .asnumpy (Y ))
46
+ assert dpt_call (x ).dtype == expected_dtype
59
47
60
48
61
49
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -119,79 +107,6 @@ def test_hyper_complex_contig(np_call, dpt_call, dtype):
119
107
)
120
108
121
109
122
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
123
- @pytest .mark .parametrize ("usm_type" , ["device" , "shared" , "host" ])
124
- def test_hyper_usm_type (np_call , dpt_call , usm_type ):
125
- q = get_queue_or_skip ()
126
-
127
- arg_dt = np .dtype ("f4" )
128
- input_shape = (10 , 10 , 10 , 10 )
129
- X = dpt .empty (input_shape , dtype = arg_dt , usm_type = usm_type , sycl_queue = q )
130
- if np_call == np .arctanh :
131
- X [..., 0 ::2 ] = - 0.4
132
- X [..., 1 ::2 ] = 0.3
133
- elif np_call == np .arccosh :
134
- X [..., 0 ::2 ] = 2.2
135
- X [..., 1 ::2 ] = 5.5
136
- else :
137
- X [..., 0 ::2 ] = - 4.4
138
- X [..., 1 ::2 ] = 5.5
139
-
140
- Y = dpt_call (X )
141
- assert Y .usm_type == X .usm_type
142
- assert Y .sycl_queue == X .sycl_queue
143
- assert Y .flags .c_contiguous
144
-
145
- expected_Y = np_call (dpt .asnumpy (X ))
146
- tol = 8 * dpt .finfo (Y .dtype ).resolution
147
- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
148
-
149
-
150
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
151
- @pytest .mark .parametrize ("dtype" , _all_dtypes )
152
- def test_hyper_order (np_call , dpt_call , dtype ):
153
- q = get_queue_or_skip ()
154
- skip_if_dtype_not_supported (dtype , q )
155
-
156
- arg_dt = np .dtype (dtype )
157
- input_shape = (4 , 4 , 4 , 4 )
158
- X = dpt .empty (input_shape , dtype = arg_dt , sycl_queue = q )
159
- if np_call == np .arctanh :
160
- X [..., 0 ::2 ] = - 0.4
161
- X [..., 1 ::2 ] = 0.3
162
- elif np_call == np .arccosh :
163
- X [..., 0 ::2 ] = 2.2
164
- X [..., 1 ::2 ] = 5.5
165
- else :
166
- X [..., 0 ::2 ] = - 4.4
167
- X [..., 1 ::2 ] = 5.5
168
-
169
- for perms in itertools .permutations (range (4 )):
170
- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
171
- with np .errstate (all = "ignore" ):
172
- expected_Y = np_call (dpt .asnumpy (U ))
173
- for ord in ["C" , "F" , "A" , "K" ]:
174
- Y = dpt_call (U , order = ord )
175
- tol = 8 * max (
176
- dpt .finfo (Y .dtype ).resolution ,
177
- np .finfo (expected_Y .dtype ).resolution ,
178
- )
179
- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
180
-
181
-
182
- @pytest .mark .parametrize ("callable" , _dpt_funcs )
183
- @pytest .mark .parametrize ("dtype" , _all_dtypes )
184
- def test_hyper_error_dtype (callable , dtype ):
185
- q = get_queue_or_skip ()
186
- skip_if_dtype_not_supported (dtype , q )
187
-
188
- x = dpt .ones (5 , dtype = dtype )
189
- y = dpt .empty_like (x , dtype = "int16" )
190
- with pytest .raises (ValueError ) as excinfo :
191
- callable (x , out = y )
192
- assert re .match ("Output array of type.*is needed" , str (excinfo .value ))
193
-
194
-
195
110
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
196
111
@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
197
112
def test_hyper_real_strided (np_call , dpt_call , dtype ):
@@ -270,46 +185,3 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
270
185
271
186
tol = 8 * dpt .finfo (dtype ).resolution
272
187
assert_allclose (dpt .asnumpy (dpt_call (yf )), Y_np , atol = tol , rtol = tol )
273
-
274
-
275
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
276
- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
277
- def test_hyper_complex_special_cases_conj_property (np_call , dpt_call , dtype ):
278
- q = get_queue_or_skip ()
279
- skip_if_dtype_not_supported (dtype , q )
280
-
281
- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
282
- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
283
-
284
- Xc_np = np .array (xc , dtype = dtype )
285
- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
286
-
287
- tol = 50 * dpt .finfo (dtype ).resolution
288
- Y = dpt_call (Xc )
289
- Yc = dpt_call (dpt .conj (Xc ))
290
-
291
- dpt .allclose (Y , dpt .conj (Yc ), atol = tol , rtol = tol )
292
-
293
-
294
- @pytest .mark .skipif (
295
- os .name != "posix" , reason = "Known to fail on Windows due to bug in NumPy"
296
- )
297
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
298
- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
299
- def test_hyper_complex_special_cases (np_call , dpt_call , dtype ):
300
- q = get_queue_or_skip ()
301
- skip_if_dtype_not_supported (dtype , q )
302
-
303
- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
304
- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
305
-
306
- Xc_np = np .array (xc , dtype = dtype )
307
- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
308
-
309
- with np .errstate (all = "ignore" ):
310
- Ynp = np_call (Xc_np )
311
-
312
- tol = 50 * dpt .finfo (dtype ).resolution
313
- Y = dpt_call (Xc )
314
- assert_allclose (dpt .asnumpy (dpt .real (Y )), np .real (Ynp ), atol = tol , rtol = tol )
315
- assert_allclose (dpt .asnumpy (dpt .imag (Y )), np .imag (Ynp ), atol = tol , rtol = tol )
0 commit comments