Skip to content

Commit 7c53289

Browse files
committed
Add tests to cover the new method
1 parent d49116b commit 7c53289

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

dpnp/tests/test_ndarray.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import dpnp
1212

1313
from .helper import (
14+
generate_random_numpy_array,
1415
get_abs_array,
1516
get_all_dtypes,
1617
get_complex_dtypes,
@@ -107,6 +108,64 @@ def test_strides(self):
107108
assert xp.full_like(a, fill_value=6) not in a
108109

109110

111+
class TestToFile:
112+
def _create_data(self):
113+
x = generate_random_numpy_array((2, 4, 3), dtype=complex)
114+
x[0, :, 1] = [numpy.nan, numpy.inf, -numpy.inf, numpy.nan]
115+
return dpnp.array(x)
116+
117+
@pytest.fixture(params=["string", "path_obj"])
118+
def tmp_filename(self, tmp_path, request):
119+
# This fixture covers two cases:
120+
# one where the filename is a string and
121+
# another where it is a pathlib object
122+
filename = tmp_path / "file"
123+
if request.param == "string":
124+
filename = str(filename)
125+
yield filename
126+
127+
def test_roundtrip_file(self, tmp_filename):
128+
a = self._create_data()
129+
130+
with open(tmp_filename, "wb") as f:
131+
a.tofile(f)
132+
133+
# reconstruct the array back from the file
134+
with open(tmp_filename, "rb") as f:
135+
b = dpnp.fromfile(f, dtype=a.dtype)
136+
assert_array_equal(b, a.asnumpy().flat)
137+
138+
def test_roundtrip(self, tmp_filename):
139+
a = self._create_data()
140+
141+
a.tofile(tmp_filename)
142+
b = dpnp.fromfile(tmp_filename, dtype=a.dtype)
143+
assert_array_equal(b, a.asnumpy().flat)
144+
145+
def test_sep(self, tmp_filename):
146+
a = dpnp.array([1.51, 2, 3.51, 4])
147+
148+
with open(tmp_filename, "w") as f:
149+
a.tofile(f, sep=",")
150+
151+
# reconstruct the array
152+
with open(tmp_filename, "r") as f:
153+
s = f.read()
154+
b = dpnp.array([float(p) for p in s.split(",")], dtype=a.dtype)
155+
assert_array_equal(a, b.asnumpy())
156+
157+
def test_format(self, tmp_filename):
158+
a = dpnp.array([1.51, 2, 3.51, 4])
159+
160+
with open(tmp_filename, "w") as f:
161+
a.tofile(f, sep=",", format="%.2f")
162+
163+
# reconstruct the array as a string
164+
with open(tmp_filename, "r") as f:
165+
s = f.read()
166+
assert_equal(s, "1.51,2.00,3.51,4.00")
167+
168+
110169
class TestToList:
111170
@pytest.mark.parametrize(
112171
"data", [[1, 2], [[1, 2], [3, 4]]], ids=["1d", "2d"]

0 commit comments

Comments
 (0)