diff --git a/docs/release-notes.md b/docs/release-notes.md index aa8001a3..9234b933 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,8 +1,10 @@ # Release notes ## 0.2.5 -- Fix numpy 2.4 and obspy 1.4.2 incompatibilities (@atrabatto). - Add SampleCoordinate for more SEED-like coordinates (@atrabattoni). +- Add `create_dirs` to `.to_netcdf` methods to create intermediate directories (@aurelienfalco). +- Fix numpy 2.4 and obspy 1.4.2 incompatibilities (@atrabattoni). + ## 0.2.4 - Add StreamWriter to write long time series to miniSEED (@marbail). diff --git a/tests/test_dataarray.py b/tests/test_dataarray.py index 4a251f7e..acdabb35 100644 --- a/tests/test_dataarray.py +++ b/tests/test_dataarray.py @@ -458,6 +458,16 @@ def test_io_attrs(self): assert result.attrs == attrs assert result.equals(da) + def test_io_create_dirs(self): + da = xd.DataArray(np.arange(3)) + with TemporaryDirectory() as dirpath: + path = os.path.join(dirpath, "subdir", "tmp.nc") + with pytest.raises(FileNotFoundError, match="No such file or directory"): + da.to_netcdf(path) + da.to_netcdf(path, create_dirs=True) + result = xd.DataArray.from_netcdf(path) + assert result.equals(da) + def test_ufunc(self): da = wavelet_wavefronts() result = np.add(da, 1) diff --git a/tests/test_datacollection.py b/tests/test_datacollection.py index 7ebd2c0f..d70cf22c 100644 --- a/tests/test_datacollection.py +++ b/tests/test_datacollection.py @@ -68,6 +68,23 @@ def test_io(self): result = xd.open_datacollection(path) assert result.equals(dc) + def test_io_create_dirs(self): + da = wavelet_wavefronts() + dc = xd.DataCollection( + { + "das1": da, + "das2": da, + }, + "instrument", + ) + with TemporaryDirectory() as dirpath: + path = os.path.join(dirpath, "subdir", "tmp.nc") + with pytest.raises(FileNotFoundError, match="No such file or directory"): + dc.to_netcdf(path) + dc.to_netcdf(path, create_dirs=True) + result = xd.DataCollection.from_netcdf(path) + assert result.equals(dc) + def 
test_depth_counter(self): da = wavelet_wavefronts() da.name = "da" diff --git a/xdas/core/dataarray.py b/xdas/core/dataarray.py index 5100cee4..1fa2cd94 100644 --- a/xdas/core/dataarray.py +++ b/xdas/core/dataarray.py @@ -1,4 +1,5 @@ import copy +import os import warnings from functools import partial @@ -825,7 +826,15 @@ def from_stream(cls, st, dims=("channel", "time")): } return cls(data, {dims[0]: channel, dims[1]: time}) - def to_netcdf(self, fname, mode="w", group=None, virtual=None, encoding=None): + def to_netcdf( + self, + fname, + mode="w", + group=None, + virtual=None, + encoding=None, + create_dirs=False, + ): """ Write DataArray contents to a netCDF file. @@ -849,6 +858,8 @@ def to_netcdf(self, fname, mode="w", group=None, virtual=None, encoding=None): the `h5netcdf` engine to write the data. If you want to use a specific plugin for compression, you can use the `hdf5plugin` package. For example, to use the ZFP compression, you can use the `hdf5plugin.Zfp` class. + create_dirs : bool, optional + Whether to create parent directories if they do not exist. Default is False. 
Examples -------- @@ -883,6 +894,12 @@ def to_netcdf(self, fname, mode="w", group=None, virtual=None, encoding=None): for coord in self.coords.values(): dataset, variable_attrs = coord.to_dataset(dataset, variable_attrs) + # create parent directories if needed + if create_dirs: + dirname = os.path.dirname(fname) + if dirname: + os.makedirs(dirname, exist_ok=True) + # write data with h5netcdf.File(fname, mode=mode) as file: # group diff --git a/xdas/core/datacollection.py b/xdas/core/datacollection.py index 4a95a5be..aafef713 100644 --- a/xdas/core/datacollection.py +++ b/xdas/core/datacollection.py @@ -231,7 +231,15 @@ def fields(self): ) return uniquifiy(out) - def to_netcdf(self, fname, mode="w", group=None, virtual=None, encoding=None): + def to_netcdf( + self, + fname, + mode="w", + group=None, + virtual=None, + encoding=None, + create_dirs=False, + ): if mode == "w" and group is None and os.path.exists(fname): os.remove(fname) for key in self: @@ -239,6 +247,10 @@ def to_netcdf(self, fname, mode="w", group=None, virtual=None, encoding=None): location = "/".join([name, str(key)]) if group is not None: location = "/".join([group, location]) + if create_dirs: + dirname = os.path.dirname(fname) + if dirname: + os.makedirs(dirname, exist_ok=True) self[key].to_netcdf( fname, mode="a", @@ -441,9 +453,22 @@ def to_mapping(self): def from_mapping(cls, data): return cls(data.values(), data.name) - def to_netcdf(self, fname, mode="w", group=None, virtual=None, encoding=None): + def to_netcdf( + self, + fname, + mode="w", + group=None, + virtual=None, + encoding=None, + create_dirs=False, + ): self.to_mapping().to_netcdf( - fname, mode=mode, group=group, virtual=virtual, encoding=encoding + fname, + mode=mode, + group=group, + virtual=virtual, + encoding=encoding, + create_dirs=create_dirs, ) @classmethod