Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion xbatcher/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
self._all_sliced_dims: dict[Hashable, int] = dict(
**self._unique_batch_dims, **self.input_dims
)

self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
Expand Down Expand Up @@ -215,7 +216,10 @@ def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset):
Calculate the number of batches per dimension
"""
self._n_batches_per_dim: dict[Hashable, int] = {
dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim]))
dim: int(
self._n_patches_per_dim[dim]
// self._n_patches_per_batch.get(dim, ds.sizes[dim])
)
for dim in self._all_sliced_dims.keys()
}

Expand Down
9 changes: 3 additions & 6 deletions xbatcher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,7 @@ def _get_sample_length(
"""
if generator.concat_input_dims:
batch_concat_dims = [
(
generator.batch_dims.get(dim) // length
if generator.batch_dims.get(dim)
else generator.ds.sizes.get(dim) // length
)
np.ceil(generator.batch_dims.get(dim, generator.ds.sizes[dim]) / length)
for dim, length in generator.input_dims.items()
]
else:
Expand Down Expand Up @@ -252,7 +248,8 @@ def validate_generator_length(generator: BatchGenerator) -> None:
)
nbatches_from_duplicate_batch_dims = np.prod(
[
generator.ds.sizes[dim] // length
generator.ds.sizes[dim]
// (generator.input_dims[dim] * np.ceil(length / generator.input_dims[dim]))
for dim, length in duplicate_batch_dims.items()
]
)
Expand Down
33 changes: 25 additions & 8 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_batch_1d(sample_ds_1d, input_size):
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
assert ds_batch.dims['x'] == input_size
assert ds_batch.sizes['x'] == input_size
expected_slice = slice(input_size * n, input_size * (n + 1))
ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
xr.testing.assert_identical(ds_batch_expected, ds_batch)
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, input_size):
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
assert ds_batch.dims['x'] == input_size
assert ds_batch.sizes['x'] == input_size
expected_slice = slice(input_size * n, input_size * (n + 1))
ds_batch_expected = ds_dropped.isel(x=expected_slice)
xr.testing.assert_identical(ds_batch_expected, ds_batch)
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_batch_1d_overlap(sample_ds_1d, input_overlap):
expected_dims = get_batch_dimensions(bg)
stride = input_size - input_overlap
for n, ds_batch in enumerate(bg):
assert ds_batch.dims['x'] == input_size
assert ds_batch.sizes['x'] == input_size
expected_slice = slice(stride * n, stride * n + input_size)
ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
xr.testing.assert_identical(ds_batch_expected, ds_batch)
Expand All @@ -204,11 +204,11 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size):
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
assert ds_batch.dims['x'] == input_size
assert ds_batch.sizes['x'] == input_size
# time and y should be collapsed into batch dimension
assert (
ds_batch.dims['sample']
== sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
ds_batch.sizes['sample']
== sample_ds_3d.sizes['y'] * sample_ds_3d.sizes['time']
)
expected_slice = slice(input_size * n, input_size * (n + 1))
ds_batch_expected = (
Expand Down Expand Up @@ -265,6 +265,23 @@ def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


def test_batch_3d_uneven_batch_input_dim(sample_ds_3d):
    """
    Check batch generation when each batch dimension is not a whole multiple
    of its corresponding input dimension.

    The generator is built with input sizes x=5, y=10 and batch sizes
    x=11, y=21, with ``concat_input_dims`` enabled. Its length is validated
    first, then every yielded batch is checked against the expected
    dimensions.
    """
    input_dims = {'x': 5, 'y': 10}
    # Deliberately not multiples of the input sizes above.
    batch_dims = {'x': 11, 'y': 21}

    generator = BatchGenerator(
        sample_ds_3d,
        input_dims=input_dims,
        batch_dims=batch_dims,
        concat_input_dims=True,
    )

    validate_generator_length(generator)
    expected = get_batch_dimensions(generator)
    for batch in generator:
        validate_batch_dimensions(expected_dims=expected, batch=batch)


@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_3d_2d_input(sample_ds_3d, input_size):
"""
Expand All @@ -279,8 +296,8 @@ def test_batch_3d_2d_input(sample_ds_3d, input_size):
yn, xn = np.unravel_index(
n,
(
(sample_ds_3d.dims['y'] // input_size),
(sample_ds_3d.dims['x'] // x_input_size),
(sample_ds_3d.sizes['y'] // input_size),
(sample_ds_3d.sizes['x'] // x_input_size),
),
)
expected_xslice = slice(x_input_size * xn, x_input_size * (xn + 1))
Expand Down
Loading