From 17d92c7deaaa706f623aeb0cf72c00d26f7000c7 Mon Sep 17 00:00:00 2001 From: s-kganz Date: Fri, 7 Nov 2025 15:42:49 -0800 Subject: [PATCH 1/3] add check for partial patches, corresponding test --- xbatcher/generators.py | 10 ++++++++++ xbatcher/tests/test_generators.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 2616ab3..8f1055f 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -89,6 +89,16 @@ def __init__( self._all_sliced_dims: dict[Hashable, int] = dict( **self._unique_batch_dims, **self.input_dims ) + + # Check that duplicate dims imply whole patches per batch + for dim, length in self.batch_dims.items(): + input_dim_length = self.input_dims.get(dim) + if input_dim_length is not None and length % input_dim_length != 0: + raise ValueError( + f'Input and batch dimension sizes imply partial batches ' + f'on dimension {dim}. Input size: {input_dim_length}; Batch size: {length}' + ) + self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet: diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 166648c..2bf4124 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -265,6 +265,20 @@ def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +def test_batch_3d_uneven_batch_input_dim(sample_ds_3d): + """ + Test for error when a batch dimension is not a multiple of the + corresponding input dimension. + """ + with pytest.raises(ValueError, match='imply partial batches'): + _ = BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 10}, + batch_dims={'x': 11, 'y': 21}, + concat_input_dims=True, + ) + + @pytest.mark.parametrize('input_size', [5, 10]) def test_batch_3d_2d_input(sample_ds_3d, input_size): """ From f5f04b03c0833480ec538d9f0ffba01bbf3956ae Mon Sep 17 00:00:00 2001 From: s-kganz Date: Mon, 8 Dec 2025 15:15:26 -0800 Subject: [PATCH 2/3] modify batch per dim calculation, update tests for consistency --- xbatcher/generators.py | 14 ++++---------- xbatcher/testing.py | 9 +++------ xbatcher/tests/test_generators.py | 19 +++++++++++-------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 8f1055f..cbb4cba 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -90,15 +90,6 @@ def __init__( **self._unique_batch_dims, **self.input_dims ) - # Check that duplicate dims imply whole patches per batch - for dim, length in self.batch_dims.items(): - input_dim_length = self.input_dims.get(dim) - if input_dim_length is not None and length % input_dim_length != 0: - raise ValueError( - f'Input and batch dimension sizes imply partial batches ' - f'on dimension {dim}. Input size: {input_dim_length}; Batch size: {length}' - ) - self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet: @@ -225,7 +216,10 @@ def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset): Calculate the number of batches per dimension """ self._n_batches_per_dim: dict[Hashable, int] = { - dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim])) + dim: int( + self._n_patches_per_dim[dim] + // self._n_patches_per_batch.get(dim, ds.sizes[dim]) + ) for dim in self._all_sliced_dims.keys() } diff --git a/xbatcher/testing.py b/xbatcher/testing.py index 219d592..7fb41e9 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -101,11 +101,7 @@ def _get_sample_length( """ if generator.concat_input_dims: batch_concat_dims = [ - ( - generator.batch_dims.get(dim) // length - if generator.batch_dims.get(dim) - else generator.ds.sizes.get(dim) // length - ) + np.ceil(generator.batch_dims.get(dim, generator.ds.sizes[dim]) / length) for dim, length in generator.input_dims.items() ] else: @@ -252,7 +248,8 @@ def validate_generator_length(generator: BatchGenerator) -> None: ) nbatches_from_duplicate_batch_dims = np.prod( [ - generator.ds.sizes[dim] // length + generator.ds.sizes[dim] + // (generator.input_dims[dim] * np.ceil(length / generator.input_dims[dim])) for dim, length in duplicate_batch_dims.items() ] ) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 2bf4124..92ca754 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -267,16 +267,19 @@ def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): def test_batch_3d_uneven_batch_input_dim(sample_ds_3d): """ - Test for error when a batch dimension is not a multiple of the + Test batch generation when a batch dimension is not a multiple of the corresponding input dimension. """ - with pytest.raises(ValueError, match='imply partial batches'): - _ = BatchGenerator( - sample_ds_3d, - input_dims={'x': 5, 'y': 10}, - batch_dims={'x': 11, 'y': 21}, - concat_input_dims=True, - ) + bg = BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 10}, + batch_dims={'x': 11, 'y': 21}, + concat_input_dims=True, + ) + validate_generator_length(bg) + expected_dims = get_batch_dimensions(bg) + for ds_batch in bg: + validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) @pytest.mark.parametrize('input_size', [5, 10]) From e6a06e476d434339b7c3114e66d572064c2d5ad4 Mon Sep 17 00:00:00 2001 From: s-kganz Date: Mon, 8 Dec 2025 15:17:59 -0800 Subject: [PATCH 3/3] .dims --> .sizes to silence test warnings --- xbatcher/tests/test_generators.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 92ca754..26a7ea7 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -96,7 +96,7 @@ def test_batch_1d(sample_ds_1d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = sample_ds_1d.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -146,7 +146,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = ds_dropped.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -187,7 +187,7 @@ def test_batch_1d_overlap(sample_ds_1d, input_overlap): expected_dims = get_batch_dimensions(bg) stride = input_size - input_overlap for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(stride * n, stride * n + input_size) ds_batch_expected = sample_ds_1d.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -204,11 +204,11 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size # time and y should be collapsed into batch dimension assert ( - ds_batch.dims['sample'] - == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time'] + ds_batch.sizes['sample'] + == sample_ds_3d.sizes['y'] * sample_ds_3d.sizes['time'] ) expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = ( @@ -296,8 +296,8 @@ def test_batch_3d_2d_input(sample_ds_3d, input_size): yn, xn = np.unravel_index( n, ( - (sample_ds_3d.dims['y'] // input_size), - (sample_ds_3d.dims['x'] // x_input_size), + (sample_ds_3d.sizes['y'] // input_size), + (sample_ds_3d.sizes['x'] // x_input_size), ), ) expected_xslice = slice(x_input_size * xn, x_input_size * (xn + 1))