diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index c1ac651..aafbe89 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -1,8 +1,5 @@ from __future__ import annotations -import os -from typing import Any - import adios2py import numpy as np import pytest @@ -70,11 +67,14 @@ def test_filename_4(tmp_path): return filename -def _open_dataset(filename: os.PathLike[Any], *, decode: bool = False) -> xr.Dataset: - ds = xr.open_dataset(filename) - if decode: - ds = _decode_dataset(ds) - return ds +@pytest.fixture +def ds_pfd_raw() -> xr.Dataset: + return xr.open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + + +@pytest.fixture +def ds_pfd_moments_raw() -> xr.Dataset: + return xr.open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") def _decode_dataset(ds: xr.Dataset) -> xr.Dataset: @@ -86,77 +86,80 @@ def _decode_dataset(ds: xr.Dataset) -> xr.Dataset: ) -def test_open_dataset(): - ds_decoded = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp", decode=True) - assert "jx_ec" in ds_decoded - assert ds_decoded.coords.keys() == set({"x", "y", "z"}) - assert ds_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 +@pytest.fixture +def ds_pfd_decoded(ds_pfd_raw) -> xr.Dataset: + return _decode_dataset(ds_pfd_raw) + + +@pytest.fixture +def ds_pfd_moments_decoded(ds_pfd_moments_raw) -> xr.Dataset: + return _decode_dataset(ds_pfd_moments_raw) + + +def test_open_dataset(ds_pfd_decoded): + assert "jx_ec" in ds_pfd_decoded + assert ds_pfd_decoded.coords.keys() == set({"x", "y", "z"}) + assert ds_pfd_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 assert np.allclose( - ds_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data + ds_pfd_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data ) -def test_component(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds_decoded = _decode_dataset(ds_raw) - assert np.all(ds_raw.jeh.isel(dim_1_9=0).data == ds_decoded.jx_ec.data) +def test_component(ds_pfd_raw, ds_pfd_decoded): + assert np.all(ds_pfd_raw.jeh.isel(dim_1_9=0).data == ds_pfd_decoded.jx_ec.data) -def test_selection(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds_decoded = _decode_dataset(ds_raw) - assert np.all( - ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data - == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data - ) +def test_selection(ds_pfd_raw, ds_pfd_decoded): + data_raw = ds_pfd_raw.jeh.isel( + dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40) + ).data + data_decoded = ds_pfd_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data + assert np.all(data_raw == data_decoded) -def test_nbytes(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds_decoded = _decode_dataset(ds_raw) - assert ds_decoded.nbytes == ds_decoded.nbytes +def _get_nbytes(ds: xr.Dataset) -> int: + return sum(arr.nbytes for arr in ds.data_vars.values()) -def test_missing_length(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") +def test_nbytes(ds_pfd_raw, ds_pfd_decoded): + assert _get_nbytes(ds_pfd_raw) == _get_nbytes(ds_pfd_decoded) + + +def test_missing_length(ds_pfd_raw): with pytest.raises(ValueError, match=r".*length.*"): pscpy.decode_psc( - ds_raw, + ds_pfd_raw, species_names=["e", "i"], corner=[0, -6.4, -25.6], ) -def test_missing_corner(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") +def test_missing_corner(ds_pfd_raw): with pytest.raises(ValueError, match=r".*corner.*"): pscpy.decode_psc( - ds_raw, + ds_pfd_raw, species_names=["e", "i"], length=[1, 12.8, 51.2], ) -def test_computed(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds_decoded = _decode_dataset(ds_raw) - ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(dim_1_9=0)) - assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data) +def test_computed(ds_pfd_raw, ds_pfd_decoded): + ds_pfd_raw = ds_pfd_raw.assign(jx=ds_pfd_raw.jeh.isel(dim_1_9=0)) + assert np.all(ds_pfd_raw.jx.data == ds_pfd_decoded.jx_ec.data) -def test_computed_via_lambda(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds_decoded = _decode_dataset(ds_raw) - ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0)) - assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data) +def test_computed_via_lambda(ds_pfd_raw, ds_pfd_decoded): + ds_pfd_raw = ds_pfd_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0)) + assert np.all(ds_pfd_raw.jx.data == ds_pfd_decoded.jx_ec.data) -def test_pfd_moments(): - ds_raw = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") - ds_decoded = _decode_dataset(ds_raw) - assert "all_1st" in ds_raw - assert "rho_i" in ds_decoded - assert np.all(ds_decoded.rho_i.data == ds_raw.all_1st.isel(dim_1_26=13).data) +def test_pfd_moments(ds_pfd_moments_raw, ds_pfd_moments_decoded): + assert "all_1st" in ds_pfd_moments_raw + assert "rho_i" in ds_pfd_moments_decoded + assert np.all( + ds_pfd_moments_decoded.rho_i.data + == ds_pfd_moments_raw.all_1st.isel(dim_1_26=13).data + ) def test_open_dataset_steps(test_filename):