diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 2616ab3..cbb4cba 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -89,6 +89,7 @@ def __init__( self._all_sliced_dims: dict[Hashable, int] = dict( **self._unique_batch_dims, **self.input_dims ) + self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet: @@ -215,7 +216,10 @@ def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset): Calculate the number of batches per dimension """ self._n_batches_per_dim: dict[Hashable, int] = { - dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim])) + dim: int( + self._n_patches_per_dim[dim] + // self._n_patches_per_batch.get(dim, ds.sizes[dim]) + ) for dim in self._all_sliced_dims.keys() } diff --git a/xbatcher/testing.py b/xbatcher/testing.py index 219d592..7fb41e9 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -101,11 +101,7 @@ def _get_sample_length( """ if generator.concat_input_dims: batch_concat_dims = [ - ( - generator.batch_dims.get(dim) // length - if generator.batch_dims.get(dim) - else generator.ds.sizes.get(dim) // length - ) + np.ceil(generator.batch_dims.get(dim, generator.ds.sizes[dim]) / length) for dim, length in generator.input_dims.items() ] else: @@ -252,7 +248,8 @@ def validate_generator_length(generator: BatchGenerator) -> None: ) nbatches_from_duplicate_batch_dims = np.prod( [ - generator.ds.sizes[dim] // length + generator.ds.sizes[dim] + // (generator.input_dims[dim] * np.ceil(length / generator.input_dims[dim])) for dim, length in duplicate_batch_dims.items() ] ) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 166648c..26a7ea7 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -96,7 +96,7 @@ def test_batch_1d(sample_ds_1d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = sample_ds_1d.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -146,7 +146,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = ds_dropped.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -187,7 +187,7 @@ def test_batch_1d_overlap(sample_ds_1d, input_overlap): expected_dims = get_batch_dimensions(bg) stride = input_size - input_overlap for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(stride * n, stride * n + input_size) ds_batch_expected = sample_ds_1d.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -204,11 +204,11 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size # time and y should be collapsed into batch dimension assert ( - ds_batch.dims['sample'] - == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time'] + ds_batch.sizes['sample'] + == sample_ds_3d.sizes['y'] * sample_ds_3d.sizes['time'] ) expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = ( @@ -265,6 +265,23 @@ def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +def test_batch_3d_uneven_batch_input_dim(sample_ds_3d): + """ + Test batch generation when a batch dimension is not a multiple of the + corresponding input dimension. + """ + bg = BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 10}, + batch_dims={'x': 11, 'y': 21}, + concat_input_dims=True, + ) + validate_generator_length(bg) + expected_dims = get_batch_dimensions(bg) + for ds_batch in bg: + validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) + + @pytest.mark.parametrize('input_size', [5, 10]) def test_batch_3d_2d_input(sample_ds_3d, input_size): """ @@ -279,8 +296,8 @@ def test_batch_3d_2d_input(sample_ds_3d, input_size): yn, xn = np.unravel_index( n, ( - (sample_ds_3d.dims['y'] // input_size), - (sample_ds_3d.dims['x'] // x_input_size), + (sample_ds_3d.sizes['y'] // input_size), + (sample_ds_3d.sizes['x'] // x_input_size), ), ) expected_xslice = slice(x_input_size * xn, x_input_size * (xn + 1))