14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
- import ctypes
18
-
19
17
import numpy as np
20
18
import pytest
21
19
22
- import dpctl
23
20
import dpctl .tensor as dpt
24
21
from dpctl .tensor ._type_utils import _can_cast
25
22
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
26
23
27
- from .utils import _all_dtypes , _compare_dtypes , _usm_types
24
+ from .utils import _all_dtypes , _compare_dtypes
28
25
29
26
30
27
@pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
@@ -61,100 +58,6 @@ def test_multiply_dtype_matrix(op1_dtype, op2_dtype):
61
58
assert (dpt .asnumpy (r ) == expected .astype (r .dtype )).all ()
62
59
63
60
64
- @pytest .mark .parametrize ("op1_usm_type" , _usm_types )
65
- @pytest .mark .parametrize ("op2_usm_type" , _usm_types )
66
- def test_multiply_usm_type_matrix (op1_usm_type , op2_usm_type ):
67
- get_queue_or_skip ()
68
-
69
- sz = 128
70
- ar1 = dpt .ones (sz , dtype = "i4" , usm_type = op1_usm_type )
71
- ar2 = dpt .ones_like (ar1 , dtype = "i4" , usm_type = op2_usm_type )
72
-
73
- r = dpt .multiply (ar1 , ar2 )
74
- assert isinstance (r , dpt .usm_ndarray )
75
- expected_usm_type = dpctl .utils .get_coerced_usm_type (
76
- (op1_usm_type , op2_usm_type )
77
- )
78
- assert r .usm_type == expected_usm_type
79
-
80
-
81
- def test_multiply_order ():
82
- get_queue_or_skip ()
83
-
84
- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
85
- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
86
- r1 = dpt .multiply (ar1 , ar2 , order = "C" )
87
- assert r1 .flags .c_contiguous
88
- r2 = dpt .multiply (ar1 , ar2 , order = "F" )
89
- assert r2 .flags .f_contiguous
90
- r3 = dpt .multiply (ar1 , ar2 , order = "A" )
91
- assert r3 .flags .c_contiguous
92
- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
93
- assert r4 .flags .c_contiguous
94
-
95
- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
96
- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
97
- r1 = dpt .multiply (ar1 , ar2 , order = "C" )
98
- assert r1 .flags .c_contiguous
99
- r2 = dpt .multiply (ar1 , ar2 , order = "F" )
100
- assert r2 .flags .f_contiguous
101
- r3 = dpt .multiply (ar1 , ar2 , order = "A" )
102
- assert r3 .flags .f_contiguous
103
- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
104
- assert r4 .flags .f_contiguous
105
-
106
- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
107
- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
108
- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
109
- assert r4 .strides == (20 , - 1 )
110
-
111
- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
112
- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
113
- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
114
- assert r4 .strides == (- 1 , 20 )
115
-
116
-
117
- def test_multiply_broadcasting ():
118
- get_queue_or_skip ()
119
-
120
- m = dpt .ones ((100 , 5 ), dtype = "i4" )
121
- v = dpt .arange (1 , 6 , dtype = "i4" )
122
-
123
- r = dpt .multiply (m , v )
124
-
125
- expected = np .multiply (
126
- np .ones ((100 , 5 ), dtype = "i4" ), np .arange (1 , 6 , dtype = "i4" )
127
- )
128
- assert (dpt .asnumpy (r ) == expected .astype (r .dtype )).all ()
129
-
130
- r2 = dpt .multiply (v , m )
131
- expected2 = np .multiply (
132
- np .arange (1 , 6 , dtype = "i4" ), np .ones ((100 , 5 ), dtype = "i4" )
133
- )
134
- assert (dpt .asnumpy (r2 ) == expected2 .astype (r2 .dtype )).all ()
135
-
136
-
137
- @pytest .mark .parametrize ("arr_dt" , _all_dtypes )
138
- def test_multiply_python_scalar (arr_dt ):
139
- q = get_queue_or_skip ()
140
- skip_if_dtype_not_supported (arr_dt , q )
141
-
142
- X = dpt .ones ((10 , 10 ), dtype = arr_dt , sycl_queue = q )
143
- py_ones = (
144
- bool (1 ),
145
- int (1 ),
146
- float (1 ),
147
- complex (1 ),
148
- np .float32 (1 ),
149
- ctypes .c_int (1 ),
150
- )
151
- for sc in py_ones :
152
- R = dpt .multiply (X , sc )
153
- assert isinstance (R , dpt .usm_ndarray )
154
- R = dpt .multiply (sc , X )
155
- assert isinstance (R , dpt .usm_ndarray )
156
-
157
-
158
61
@pytest .mark .parametrize ("arr_dt" , _all_dtypes )
159
62
@pytest .mark .parametrize ("sc" , [bool (1 ), int (1 ), float (1 ), complex (1 )])
160
63
def test_multiply_python_scalar_gh1219 (arr_dt , sc ):
@@ -175,22 +78,6 @@ def test_multiply_python_scalar_gh1219(arr_dt, sc):
175
78
assert _compare_dtypes (R .dtype , Rnp .dtype , sycl_queue = q )
176
79
177
80
178
- @pytest .mark .parametrize ("dtype" , _all_dtypes )
179
- def test_multiply_inplace_python_scalar (dtype ):
180
- q = get_queue_or_skip ()
181
- skip_if_dtype_not_supported (dtype , q )
182
- X = dpt .ones ((10 , 10 ), dtype = dtype , sycl_queue = q )
183
- dt_kind = X .dtype .kind
184
- if dt_kind in "ui" :
185
- X *= int (1 )
186
- elif dt_kind == "f" :
187
- X *= float (1 )
188
- elif dt_kind == "c" :
189
- X *= complex (1 )
190
- elif dt_kind == "b" :
191
- X *= bool (1 )
192
-
193
-
194
81
@pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
195
82
@pytest .mark .parametrize ("op2_dtype" , _all_dtypes )
196
83
def test_multiply_inplace_dtype_matrix (op1_dtype , op2_dtype ):
0 commit comments