diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 794eeca..97f1861 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -40,12 +40,12 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.12"]
+        python-version: ["3.10", "3.12"]
         runs-on: [ubuntu-latest, macos-latest, windows-latest]
-        include:
-          - python-version: pypy-3.10
-            runs-on: ubuntu-latest
+        # include:
+        #   - python-version: pypy-3.10
+        #     runs-on: ubuntu-latest
 
     steps:
       - uses: actions/checkout@v4
diff --git a/CMakeLists.txt b/CMakeLists.txt
deleted file mode 100644
index 3e680b1..0000000
--- a/CMakeLists.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-cmake_minimum_required(VERSION 3.15...3.26)
-project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
diff --git a/pyproject.toml b/pyproject.toml
index 455edaa..97e855c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,6 @@
 [build-system]
-requires = ["scikit-build-core"]
-build-backend = "scikit_build_core.build"
-
+build-backend = "setuptools.build_meta"
+requires = ["setuptools>=42", "setuptools-scm>=7"]
 
 [project]
 name = "xarray-pschdf5"
@@ -30,21 +29,11 @@ classifiers = [
   "Typing :: Typed",
 ]
 dynamic = ["version"]
-dependencies = [
-  "pugixml >= 0.5.0",
-  "h5py",
-  "xarray",
-]
+dependencies = ["pugixml >= 0.5.0", "h5py", "xarray"]
 
 [project.optional-dependencies]
-test = [
-  "pytest >=6",
-  "pytest-cov >=3",
-]
-dev = [
-  "pytest >=6",
-  "pytest-cov >=3",
-]
+test = ["pytest >=6", "pytest-cov >=3", "typing-extensions"]
+dev = ["pytest >=6", "pytest-cov >=3", "typing-extensions"]
 docs = [
   "sphinx>=7.0",
   "myst_parser>=0.13",
@@ -60,7 +49,7 @@ Discussions = "https://github.com/psc-code/xarray-pschdf5/discussions"
 Changelog = "https://github.com/psc-code/xarray-pschdf5/releases"
 
 [project.entry-points."xarray.backends"]
-pschdf5 = "xarray_pschdf5:PscHdf5Entrypoint"
+pschdf5 = "xarray_pschdf5.pschdf5_backend:PscHdf5Entrypoint"
 
 [tool.scikit-build]
 minimum-version = "0.4"
@@ -83,21 +72,14 @@ test-skip = ["*universal2:arm64"]
 minversion = "6.0"
 addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
 xfail_strict = true
-filterwarnings = [
-  "error",
-]
+filterwarnings = ["error"]
 log_cli_level = "INFO"
-testpaths = [
-  "tests",
-]
+testpaths = ["tests"]
 
 [tool.coverage]
 run.source = ["xarray_pschdf5"]
-report.exclude_also = [
-  '\.\.\.',
-  'if typing.TYPE_CHECKING:',
-]
+report.exclude_also = ['\.\.\.', 'if typing.TYPE_CHECKING:']
 
 [tool.mypy]
 files = ["src", "tests"]
@@ -120,33 +102,33 @@ src = ["src"]
 
 [tool.ruff.lint]
 extend-select = [
-  "B",   # flake8-bugbear
-  "I",   # isort
-  "ARG", # flake8-unused-arguments
-  "C4",  # flake8-comprehensions
-  "EM",  # flake8-errmsg
-  "ICN", # flake8-import-conventions
-  "G",   # flake8-logging-format
-  "PGH", # pygrep-hooks
-  "PIE", # flake8-pie
-  "PL",  # pylint
-  "PT",  # flake8-pytest-style
-  "PTH", # flake8-use-pathlib
-  "RET", # flake8-return
-  "RUF", # Ruff-specific
-  "SIM", # flake8-simplify
-  "T20", # flake8-print
-  "UP",  # pyupgrade
-  "YTT", # flake8-2020
-  "EXE", # flake8-executable
-  "NPY", # NumPy specific rules
-  "PD",  # pandas-vet
+  "B",    # flake8-bugbear
+  "I",    # isort
+  "ARG",  # flake8-unused-arguments
+  "C4",   # flake8-comprehensions
+  "EM",   # flake8-errmsg
+  "ICN",  # flake8-import-conventions
+  "G",    # flake8-logging-format
+  "PGH",  # pygrep-hooks
+  "PIE",  # flake8-pie
+  "PL",   # pylint
+  "PT",   # flake8-pytest-style
+  "PTH",  # flake8-use-pathlib
+  "RET",  # flake8-return
+  "RUF",  # Ruff-specific
+  "SIM",  # flake8-simplify
+  "T20",  # flake8-print
+  "UP",   # pyupgrade
+  "YTT",  # flake8-2020
+  "EXE",  # flake8-executable
+  "NPY",  # NumPy specific rules
+  "PD",   # pandas-vet
 ]
 ignore = [
-  "PLR09",   # Too many <...>
-  "PLR2004", # Magic value used in comparison
-  "ISC001",  # Conflicts with formatter
-  "C408",    # like my dict() calls
+  "PLR09",    # Too many <...>
+  "PLR2004",  # Magic value used in comparison
+  "ISC001",   # Conflicts with formatter
+  "C408",     # like my dict() calls
   "RET504",
 ]
 isort.required-imports = ["from __future__ import annotations"]
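Note on the entry-point move: the backend class now lives in xarray_pschdf5.pschdf5_backend, and a quick smoke check that the engine still registers after an editable install might look like this (a sketch; list_engines() is xarray's public helper for inspecting registered backends):

    import xarray as xr

    # "pschdf5" should appear once the package is installed, courtesy of
    # the [project.entry-points."xarray.backends"] table above
    assert "pschdf5" in xr.backends.list_engines()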
"EXE", # flake8-executable + "NPY", # NumPy specific rules + "PD", # pandas-vet ] ignore = [ - "PLR09", # Too many <...> - "PLR2004", # Magic value used in comparison - "ISC001", # Conflicts with formatter - "C408", # like my dict() calls + "PLR09", # Too many <...> + "PLR2004", # Magic value used in comparison + "ISC001", # Conflicts with formatter + "C408", # like my dict() calls "RET504", ] isort.required-imports = ["from __future__ import annotations"] diff --git a/src/main.cpp b/src/main.cpp deleted file mode 100644 index 2b8ed18..0000000 --- a/src/main.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include - -int add(int i, int j) { return i + j; } - -namespace py = pybind11; - -PYBIND11_MODULE(_core, m) { - m.doc() = R"pbdoc( - Pybind11 example plugin - ----------------------- - .. currentmodule:: python_example - .. autosummary:: - :toctree: _generate - add - subtract - )pbdoc"; - - m.def("add", &add, R"pbdoc( - Add two numbers - Some other explanation about the add function. - )pbdoc"); - - m.def( - "subtract", [](int i, int j) { return i - j; }, R"pbdoc( - Subtract two numbers - Some other explanation about the subtract function. - )pbdoc"); -} diff --git a/src/xarray_pschdf5/__init__.py b/src/xarray_pschdf5/__init__.py index 6eea410..d5d160d 100644 --- a/src/xarray_pschdf5/__init__.py +++ b/src/xarray_pschdf5/__init__.py @@ -6,6 +6,10 @@ from __future__ import annotations +import pathlib + from ._version import version as __version__ -__all__ = ["__version__"] +sample_dir = pathlib.Path(__file__).parent / "sample" + +__all__ = ["__version__", "sample_dir"] diff --git a/src/xarray_pschdf5/pschdf5_backend.py b/src/xarray_pschdf5/pschdf5_backend.py index dc9c430..773a946 100644 --- a/src/xarray_pschdf5/pschdf5_backend.py +++ b/src/xarray_pschdf5/pschdf5_backend.py @@ -1,22 +1,33 @@ from __future__ import annotations +import dataclasses import os +import pathlib from collections import OrderedDict +from collections.abc import Iterable from typing import Any, ClassVar import h5py import numpy as np import xarray as xr from pugixml import pugi +from typing_extensions import override from xarray.backends import BackendEntrypoint class PscHdf5Entrypoint(BackendEntrypoint): + @override def open_dataset( self, filename_or_obj, *, - drop_variables=None, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + decode_coords: bool = True, + drop_variables: str | Iterable[str] | None = None, + use_cftime: bool | None = None, + decode_timedelta: bool | None = None, # other backend specific keyword arguments # `chunks` and `cache` DO NOT go here, they are handled by xarray ): @@ -24,70 +35,167 @@ def open_dataset( open_dataset_parameters: ClassVar[Any] = ["filename_or_obj", "drop_variables"] - def guess_can_open(self, filename_or_obj): - if filename_or_obj.endswith(".xdmf"): - return True + def guess_can_open(self, filename_or_obj) -> bool: + if not isinstance(filename_or_obj, str | os.PathLike): + return False - return False + filename_or_obj = pathlib.Path(filename_or_obj) + + return filename_or_obj.suffix == ".xdmf" description = "XArray reader for PSC HDF5 data" url = "https://link_to/your_backend/documentation" # FIXME +@dataclasses.dataclass +class VariableInfo: + """Class for keeping track of per-variable info.""" + + shape: tuple + dims: tuple + times: list[np.datetime64] = dataclasses.field(default_factory=list) + paths: list[str] = dataclasses.field(default_factory=list) + + def pschdf5_open_dataset(filename_or_obj, *, drop_variables=None): - dirname, 
diff --git a/src/xarray_pschdf5/pschdf5_backend.py b/src/xarray_pschdf5/pschdf5_backend.py
index dc9c430..773a946 100644
--- a/src/xarray_pschdf5/pschdf5_backend.py
+++ b/src/xarray_pschdf5/pschdf5_backend.py
@@ -1,22 +1,33 @@
 from __future__ import annotations
 
+import dataclasses
 import os
+import pathlib
 from collections import OrderedDict
+from collections.abc import Iterable
 from typing import Any, ClassVar
 
 import h5py
 import numpy as np
 import xarray as xr
 from pugixml import pugi
+from typing_extensions import override
 from xarray.backends import BackendEntrypoint
 
 
 class PscHdf5Entrypoint(BackendEntrypoint):
+    @override
     def open_dataset(
         self,
         filename_or_obj,
         *,
-        drop_variables=None,
+        mask_and_scale: bool = True,
+        decode_times: bool = True,
+        concat_characters: bool = True,
+        decode_coords: bool = True,
+        drop_variables: str | Iterable[str] | None = None,
+        use_cftime: bool | None = None,
+        decode_timedelta: bool | None = None,
         # other backend specific keyword arguments
         # `chunks` and `cache` DO NOT go here, they are handled by xarray
     ):
@@ -24,70 +35,167 @@ def open_dataset(
 
     open_dataset_parameters: ClassVar[Any] = ["filename_or_obj", "drop_variables"]
 
-    def guess_can_open(self, filename_or_obj):
-        if filename_or_obj.endswith(".xdmf"):
-            return True
+    def guess_can_open(self, filename_or_obj) -> bool:
+        if not isinstance(filename_or_obj, str | os.PathLike):
+            return False
 
-        return False
+        filename_or_obj = pathlib.Path(filename_or_obj)
+
+        return filename_or_obj.suffix == ".xdmf"
 
     description = "XArray reader for PSC HDF5 data"
     url = "https://link_to/your_backend/documentation"  # FIXME
 
 
+@dataclasses.dataclass
+class VariableInfo:
+    """Class for keeping track of per-variable info."""
+
+    shape: tuple
+    dims: tuple
+    times: list[np.datetime64] = dataclasses.field(default_factory=list)
+    paths: list[str] = dataclasses.field(default_factory=list)
+
+
 def pschdf5_open_dataset(filename_or_obj, *, drop_variables=None):
-    dirname, basename = os.path.split(filename_or_obj)
+    filename_or_obj = pathlib.Path(filename_or_obj)
+    if isinstance(drop_variables, str):
+        drop_variables = [drop_variables]
+    elif drop_variables is None:
+        drop_variables = []
+    drop_variables = set(drop_variables)
+
+    dirname = filename_or_obj.parent
 
     meta = read_xdmf(filename_or_obj)
-    meta["run"] = "RUN"
-    grids = meta["grids"]
 
-    vars = dict()
-    assert len(grids) == 1
-    for _, grid in grids.items():
+    var_infos = {}
+    n_times = len(meta)
+    for spatial in meta:
+        grids = spatial["grids"]
+        assert len(grids) == 1
+        _, grid = next(iter(grids.items()))
         for fldname, fld in grid["fields"].items():
             if fldname in drop_variables:
                 continue
-            # data_dims = fld["dims"]
-            data_path = fld["path"]
-            h5_filename, h5_path = fld["path"].split(":")
-            h5_file = h5py.File(dirname + "/" + h5_filename)
-            data = h5_file[h5_path][:].T
-
-            data_attrs = dict(path=data_path)
-            vars[fldname] = xr.DataArray(
-                data=data, dims=["x", "y", "z"], attrs=data_attrs
+            if fldname not in var_infos:
+                var_infos[fldname] = VariableInfo(
+                    shape=(n_times,) + fld["dims"], dims=_make_dims(fld)
+                )
+            # epoch-relative timestamp; the XDMF time value is in seconds
+            time = np.datetime64("2000-01-01T00:00:00") + np.timedelta64(
+                int(spatial["time"]) * 1000000000, "ns"
             )
+            var_infos[fldname].times.append(time)
+            var_infos[fldname].paths.append(fld["path"])
 
-    coords = {
-        "xyz"[d]: make_crd(
-            grid["topology"]["dims"][d],
-            grid["geometry"]["origin"][d],
-            grid["geometry"]["spacing"][d],
-        )
-        for d in range(3)
-    }
+    _, var_info = next(iter(var_infos.items()))
+    coords = _make_coords(grid, var_info.times)
+
+    vars = {}
+    for name, info in var_infos.items():
+        da = xr.DataArray(data=np.empty(info.shape), dims=info.dims)
+        for it, path in enumerate(info.paths):
+            h5_filename, h5_path = path.split(":")
+            h5_file = h5py.File(dirname / h5_filename)
+            da[it, :, :] = h5_file[h5_path]
 
-    attrs = dict(run=meta["run"], time=meta["time"])
+        vars[name] = da
 
+    attrs = {}
     ds = xr.Dataset(vars, coords=coords, attrs=attrs)
     #    ds.set_close(my_close_method)
 
     return ds
 
 
+def _make_dims(fld):
+    match len(fld["dims"]):
+        case 2:
+            return ("time", "lats", "longs")
+        case 3:
+            return ("time", "z", "y", "x")
+    msg = f"unsupported number of dimensions: {len(fld['dims'])}"
+    raise ValueError(msg)
+
+
+def _make_coords(grid, times):
+    dims = grid["topology"]["dims"]
+    match grid["topology"]["type"]:
+        case "3DCoRectMesh":
+            coords = {
+                "xyz"[d]: (
+                    "xyz"[d],
+                    make_crd(
+                        dims[d],
+                        grid["geometry"]["origin"][d],
+                        grid["geometry"]["spacing"][d],
+                    ),
+                )
+                for d in range(3)
+            }
+        case "2DSMesh":
+            coords = {
+                "lats": ("lats", np.linspace(90, -90, dims[0])),
+                "longs": ("longs", np.linspace(-180, 180, dims[1])),
+                "colats": ("lats", np.linspace(0, 180, dims[0])),
+                "mlts": ("longs", np.linspace(0, 24, dims[1])),
+            }
+        case _:
+            msg = f"unsupported topology type: {grid['topology']['type']}"
+            raise ValueError(msg)
+
+    coords["time"] = ("time", np.asarray(times))
+    return coords
+
+
 def make_crd(dim, origin, spacing):
     return origin + np.arange(0.5, dim) * spacing
 
 
-def read_xdmf(filename):
-    doc = pugi.XMLDocument()
-    result = doc.load_file(filename)
-    if not result:
-        raise f"parse error: status={result.status} description={result.description()}"
+def _parse_dimensions_attr(node):
+    attr = node.attribute("Dimensions")
+    return tuple(int(d) for d in attr.value().split(" "))
 
-    grid_collection = doc.child("Xdmf").child("Domain").child("Grid")
-    assert grid_collection.attribute("GridType").value() == "Collection"
+
+def _parse_geometry_origin_dxdydz(geometry):
+    geo = dict()
+    for child in geometry.children():
+        if child.attribute("Name").value() == "Origin":
+            geo["origin"] = np.asarray(
+                [float(x) for x in child.text().as_string().split(" ")]
+            )
+
+        if child.attribute("Name").value() == "Spacing":
+            geo["spacing"] = np.asarray(
+                [float(x) for x in child.text().as_string().split(" ")]
+            )
+    return geo
+
+
+def _parse_geometry_xyz(geometry):
+    data_item = geometry.child("DataItem")
+    assert data_item.attribute("Format").value() == "XML"
+    dims = _parse_dimensions_attr(data_item)
+    data = np.loadtxt(data_item.text().as_string().splitlines())
+    geo = {"data_item": data.reshape(dims)}
+    return geo
+
+
+def _parse_temporal_collection(filename, grid_collection):
+    temporal = []
+    for node in grid_collection.children():
+        href = node.attribute("href").value()
+        doc = pugi.XMLDocument()
+        result = doc.load_file(filename.parent / href)
+        if not result:
+            msg = f"parse error: status={result.status} description={result.description()}"
+            raise RuntimeError(msg)
+
+        temporal.append(
+            _parse_spatial_collection(doc.child("Xdmf").child("Domain").child("Grid"))
+        )
+
+    return temporal
+
+
+def _parse_spatial_collection(grid_collection):
     grid_time = grid_collection.child("Time")
     assert grid_time.attribute("Type").value() == "Single"
     time = grid_time.attribute("Value").value()
@@ -99,28 +207,18 @@
         grid = {}
         grid_name = node.attribute("Name").value()
         topology = node.child("Topology")
-        # assert topology.attribute('TopologyType').value() == '3DCoRectMesh'
-        dims = topology.attribute("Dimensions").value()
-        dims = np.asarray([int(d) - 1 for d in dims.split(" ")])[::-1]
+        dims = _parse_dimensions_attr(topology)
         grid["topology"] = {
             "type": topology.attribute("TopologyType").value(),
             "dims": dims,
         }
 
         geometry = node.child("Geometry")
-        assert geometry.attribute("GeometryType").value() == "Origin_DxDyDz"
-
-        grid["geometry"] = dict()
-        for child in geometry.children():
-            if child.attribute("Name").value() == "Origin":
-                grid["geometry"]["origin"] = np.asarray(
-                    [float(x) for x in child.text().as_string().split(" ")]
-                )[::-1]
-
-            if child.attribute("Name").value() == "Spacing":
-                grid["geometry"]["spacing"] = np.asarray(
-                    [float(x) for x in child.text().as_string().split(" ")]
-                )[::-1]
+        match geometry.attribute("GeometryType").value():
+            case "Origin_DxDyDz":
+                grid["geometry"] = _parse_geometry_origin_dxdydz(geometry)
+            case "XYZ":
+                grid["geometry"] = _parse_geometry_xyz(geometry)
 
         flds = OrderedDict()
         for child in node.children():
@@ -129,15 +227,30 @@
             fld = child.attribute("Name").value()
             item = child.child("DataItem")
 
-            fld_dims = np.asarray(
-                [int(d) for d in item.attribute("Dimensions").value().split(" ")]
-            )[::-1]
+            fld_dims = _parse_dimensions_attr(item)
             assert np.all(fld_dims == dims)
             assert item.attribute("Format").value() == "HDF"
-            path = item.text().as_string()
+            path = item.text().as_string().strip()
             flds[fld] = {"path": path, "dims": dims}
 
         grid["fields"] = flds
         rv["grids"][grid_name] = grid
 
     return rv
+
+
+def read_xdmf(filename):
+    doc = pugi.XMLDocument()
+    result = doc.load_file(filename)
+    if not result:
+        msg = f"parse error: status={result.status} description={result.description()}"
+        raise RuntimeError(msg)
+
+    grid_collection = doc.child("Xdmf").child("Domain").child("Grid")
+    assert grid_collection.attribute("GridType").value() == "Collection"
+    match grid_collection.attribute("CollectionType").value():
+        case "Spatial":
+            return [_parse_spatial_collection(grid_collection)]
+        case "Temporal":
+            return _parse_temporal_collection(filename, grid_collection)
+        case unknown:
+            msg = f"unknown CollectionType: {unknown}"
+            raise RuntimeError(msg)
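For orientation, read_xdmf() now returns a list with one entry per time step, each holding the per-grid metadata that pschdf5_open_dataset() consumes. Roughly, with illustrative grid/field names and dims modeled on the sample files below:

    meta = [
        {
            # <Time Value="..."> from the spatial collection, a string of seconds
            "time": "4020",
            "grids": {
                "iof": {  # hypothetical grid name
                    "topology": {"type": "2DSMesh", "dims": (181, 61)},
                    "geometry": {},
                    "fields": {
                        # path is "file.h5:/dataset", split on ":" when reading
                        "pot": {
                            "path": "sample_xdmf.iof.004020_p000000.h5:/pot",
                            "dims": (181, 61),
                        },
                    },
                },
            },
        },
        # ...one dict per referenced spatial collection
    ]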
diff --git a/src/xarray_pschdf5/sample/sample_xdmf.iof.004020.xdmf b/src/xarray_pschdf5/sample/sample_xdmf.iof.004020.xdmf
new file mode 100644
index 0000000..45c9688
--- /dev/null
+++ b/src/xarray_pschdf5/sample/sample_xdmf.iof.004020.xdmf
@@ -0,0 +1,11197 @@
+[11,197 lines of XDMF XML: a spatial collection for time step 004020]
diff --git a/src/xarray_pschdf5/sample/sample_xdmf.iof.004020_p000000.h5 b/src/xarray_pschdf5/sample/sample_xdmf.iof.004020_p000000.h5
new file mode 100644
index 0000000..8d5fd14
Binary files /dev/null and b/src/xarray_pschdf5/sample/sample_xdmf.iof.004020_p000000.h5 differ
diff --git a/src/xarray_pschdf5/sample/sample_xdmf.iof.004040.xdmf b/src/xarray_pschdf5/sample/sample_xdmf.iof.004040.xdmf
new file mode 100644
index 0000000..5c5cc10
--- /dev/null
+++ b/src/xarray_pschdf5/sample/sample_xdmf.iof.004040.xdmf
@@ -0,0 +1,11197 @@
+[11,197 lines of XDMF XML: a spatial collection for time step 004040]
diff --git a/src/xarray_pschdf5/sample/sample_xdmf.iof.xdmf b/src/xarray_pschdf5/sample/sample_xdmf.iof.xdmf
new file mode 100644
index 0000000..feb44b3
--- /dev/null
+++ b/src/xarray_pschdf5/sample/sample_xdmf.iof.xdmf
@@ -0,0 +1,9 @@
+[9 lines of XDMF XML: a temporal collection referencing the two files above]
diff --git a/tests/test_compiled.py b/tests/test_compiled.py
deleted file mode 100644
index 9e5d96a..0000000
--- a/tests/test_compiled.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from __future__ import annotations
-
-import xarray_pschdf5._core as m
-
-
-def test_add():
-    assert m.add(2, 3) == 5
-
-
-def test_subtract():
-    assert m.subtract(7, 5) == 2
diff --git a/tests/test_read.py b/tests/test_read.py
new file mode 100644
index 0000000..83d00aa
--- /dev/null
+++ b/tests/test_read.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+import xarray as xr
+
+import xarray_pschdf5
+
+
+def test_read_sample():
+    xr.open_dataset(xarray_pschdf5.sample_dir / "sample_xdmf.iof.004020.xdmf")
+
+
+def test_read_sample_all():
+    ds = xr.open_dataset(xarray_pschdf5.sample_dir / "sample_xdmf.iof.xdmf")
+    print(ds)
+    assert set(ds.coords) == {"time", "lats", "longs", "colats", "mlts"}
+    assert ds.pot.dims == ("time", "lats", "longs")
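End to end, the backend is exercised interactively the same way as in the tests; a minimal session against the bundled samples:

    import xarray as xr

    import xarray_pschdf5

    # the temporal-collection file stitches both time steps into one dataset,
    # with lats/longs (plus colats/mlts) coordinates from the 2DSMesh branch
    ds = xr.open_dataset(xarray_pschdf5.sample_dir / "sample_xdmf.iof.xdmf")
    assert ds["pot"].dims == ("time", "lats", "longs")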