Skip to content

Commit 233c5af

Browse files
committed
Add tests
1 parent 206248c commit 233c5af

File tree

6 files changed

+23
-12
lines changed

6 files changed

+23
-12
lines changed

dpnp/backend/extensions/vm/erf_funcs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,6 @@ void init_erf_funcs(py::module_ m)
193193
m, "_erfcx",
194194
"Call `erfcx` function from OneMKL VM library to compute the scaled "
195195
"complementary error function value of vector elements",
196-
impl::erfc_contig_dispatch_vector);
196+
impl::erfcx_contig_dispatch_vector);
197197
}
198198
} // namespace dpnp::extensions::vm

dpnp/tests/test_special.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@with_requires("scipy")
17-
@pytest.mark.parametrize("func", ["erf", "erfc"])
17+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
1818
class TestCommon:
1919
@pytest.mark.parametrize(
2020
"dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
@@ -65,19 +65,31 @@ def test_complex(self, func, dt):
6565

6666
class TestConsistency:
6767

68-
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
69-
def test_erfc(self):
68+
def _check_variant_func(self, func, other_func, rtol, atol=0):
7069
# TODO: replace with dpnp.random.RandomState, once pareto is added
7170
rng = numpy.random.RandomState(1234)
7271
n = 10000
7372
a = rng.pareto(0.02, n) * (2 * rng.randint(0, 2, n) - 1)
7473
a = dpnp.array(a)
74+
a = a[::-1]
7575

76-
res = 1 - dpnp.scipy.special.erf(a)
76+
res = other_func(a)
7777
mask = dpnp.isfinite(res)
7878
a = a[mask]
7979

80-
tol = 8 * dpnp.finfo(a).resolution
81-
assert dpnp.allclose(
82-
dpnp.scipy.special.erfc(a), res[mask], rtol=tol, atol=tol
80+
assert dpnp.allclose(func(a), res[mask], rtol=rtol, atol=atol)
81+
82+
def test_erfc(self):
83+
self._check_variant_func(
84+
dpnp.special.erfc,
85+
lambda z: 1 - dpnp.special.erf(z),
86+
rtol=1e-12,
87+
atol=1e-14,
88+
)
89+
90+
def test_erfcx(self):
91+
self._check_variant_func(
92+
dpnp.special.erfcx,
93+
lambda z: dpnp.exp(z * z) * dpnp.special.erfc(z),
94+
rtol=1e-12,
8395
)

dpnp/tests/test_strides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_reduce_hypot(dtype, stride):
167167

168168

169169
@with_requires("scipy")
170-
@pytest.mark.parametrize("func", ["erf", "erfc"])
170+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
171171
@pytest.mark.parametrize("stride", [2, -1, -3])
172172
def test_erf_funcs(func, stride):
173173
import scipy.special

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,7 @@ def test_interp(device, left, right, period):
14881488
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14891489

14901490

1491-
@pytest.mark.parametrize("func", ["erf", "erfc"])
1491+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
14921492
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
14931493
def test_erf_funcs(func, device):
14941494
x = dpnp.linspace(-3, 3, num=5, device=device)

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ def test_choose(usm_type_x, usm_type_ind):
12961296
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
12971297

12981298

1299-
@pytest.mark.parametrize("func", ["erf", "erfc"])
1299+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
13001300
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13011301
def test_erf_funcs(func, usm_type):
13021302
x = dpnp.linspace(-3, 3, num=5, usm_type=usm_type)

dpnp/tests/third_party/cupyx/scipy_tests/special_tests/test_erf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def test_erf(self):
2727
def test_erfc(self):
2828
self.check_unary("erfc")
2929

30-
@pytest.mark.skip("erfcx() is not supported yet")
3130
@testing.with_requires("scipy>=1.16.0")
3231
def test_erfcx(self):
3332
self.check_unary("erfcx")

0 commit comments

Comments
 (0)