Skip to content

Commit 9f6c64d

Browse files
committed
remove redundant tests for multiply
1 parent d282e3a commit 9f6c64d

File tree

1 file changed

+1
-114
lines changed

1 file changed

+1
-114
lines changed

dpctl/tests/elementwise/test_multiply.py

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,14 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import ctypes
18-
1917
import numpy as np
2018
import pytest
2119

22-
import dpctl
2320
import dpctl.tensor as dpt
2421
from dpctl.tensor._type_utils import _can_cast
2522
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2623

27-
from .utils import _all_dtypes, _compare_dtypes, _usm_types
24+
from .utils import _all_dtypes, _compare_dtypes
2825

2926

3027
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
@@ -61,100 +58,6 @@ def test_multiply_dtype_matrix(op1_dtype, op2_dtype):
6158
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
6259

6360

64-
@pytest.mark.parametrize("op1_usm_type", _usm_types)
65-
@pytest.mark.parametrize("op2_usm_type", _usm_types)
66-
def test_multiply_usm_type_matrix(op1_usm_type, op2_usm_type):
67-
get_queue_or_skip()
68-
69-
sz = 128
70-
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
71-
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
72-
73-
r = dpt.multiply(ar1, ar2)
74-
assert isinstance(r, dpt.usm_ndarray)
75-
expected_usm_type = dpctl.utils.get_coerced_usm_type(
76-
(op1_usm_type, op2_usm_type)
77-
)
78-
assert r.usm_type == expected_usm_type
79-
80-
81-
def test_multiply_order():
82-
get_queue_or_skip()
83-
84-
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
85-
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
86-
r1 = dpt.multiply(ar1, ar2, order="C")
87-
assert r1.flags.c_contiguous
88-
r2 = dpt.multiply(ar1, ar2, order="F")
89-
assert r2.flags.f_contiguous
90-
r3 = dpt.multiply(ar1, ar2, order="A")
91-
assert r3.flags.c_contiguous
92-
r4 = dpt.multiply(ar1, ar2, order="K")
93-
assert r4.flags.c_contiguous
94-
95-
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
96-
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
97-
r1 = dpt.multiply(ar1, ar2, order="C")
98-
assert r1.flags.c_contiguous
99-
r2 = dpt.multiply(ar1, ar2, order="F")
100-
assert r2.flags.f_contiguous
101-
r3 = dpt.multiply(ar1, ar2, order="A")
102-
assert r3.flags.f_contiguous
103-
r4 = dpt.multiply(ar1, ar2, order="K")
104-
assert r4.flags.f_contiguous
105-
106-
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
107-
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
108-
r4 = dpt.multiply(ar1, ar2, order="K")
109-
assert r4.strides == (20, -1)
110-
111-
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
112-
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
113-
r4 = dpt.multiply(ar1, ar2, order="K")
114-
assert r4.strides == (-1, 20)
115-
116-
117-
def test_multiply_broadcasting():
118-
get_queue_or_skip()
119-
120-
m = dpt.ones((100, 5), dtype="i4")
121-
v = dpt.arange(1, 6, dtype="i4")
122-
123-
r = dpt.multiply(m, v)
124-
125-
expected = np.multiply(
126-
np.ones((100, 5), dtype="i4"), np.arange(1, 6, dtype="i4")
127-
)
128-
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
129-
130-
r2 = dpt.multiply(v, m)
131-
expected2 = np.multiply(
132-
np.arange(1, 6, dtype="i4"), np.ones((100, 5), dtype="i4")
133-
)
134-
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()
135-
136-
137-
@pytest.mark.parametrize("arr_dt", _all_dtypes)
138-
def test_multiply_python_scalar(arr_dt):
139-
q = get_queue_or_skip()
140-
skip_if_dtype_not_supported(arr_dt, q)
141-
142-
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
143-
py_ones = (
144-
bool(1),
145-
int(1),
146-
float(1),
147-
complex(1),
148-
np.float32(1),
149-
ctypes.c_int(1),
150-
)
151-
for sc in py_ones:
152-
R = dpt.multiply(X, sc)
153-
assert isinstance(R, dpt.usm_ndarray)
154-
R = dpt.multiply(sc, X)
155-
assert isinstance(R, dpt.usm_ndarray)
156-
157-
15861
@pytest.mark.parametrize("arr_dt", _all_dtypes)
15962
@pytest.mark.parametrize("sc", [bool(1), int(1), float(1), complex(1)])
16063
def test_multiply_python_scalar_gh1219(arr_dt, sc):
@@ -175,22 +78,6 @@ def test_multiply_python_scalar_gh1219(arr_dt, sc):
17578
assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)
17679

17780

178-
@pytest.mark.parametrize("dtype", _all_dtypes)
179-
def test_multiply_inplace_python_scalar(dtype):
180-
q = get_queue_or_skip()
181-
skip_if_dtype_not_supported(dtype, q)
182-
X = dpt.ones((10, 10), dtype=dtype, sycl_queue=q)
183-
dt_kind = X.dtype.kind
184-
if dt_kind in "ui":
185-
X *= int(1)
186-
elif dt_kind == "f":
187-
X *= float(1)
188-
elif dt_kind == "c":
189-
X *= complex(1)
190-
elif dt_kind == "b":
191-
X *= bool(1)
192-
193-
19481
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
19582
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
19683
def test_multiply_inplace_dtype_matrix(op1_dtype, op2_dtype):

0 commit comments

Comments
 (0)