diff --git a/regridding/_regrid/_tests/test_regrid.py b/regridding/_regrid/_tests/test_regrid.py index 35ef7c9..97f2878 100644 --- a/regridding/_regrid/_tests/test_regrid.py +++ b/regridding/_regrid/_tests/test_regrid.py @@ -39,6 +39,15 @@ None, np.square(np.linspace(-1, 1, num=11)), ), + ( + (np.linspace(-1, 1, num=11) * u.mm,), + (np.linspace(-1, 1, num=11) * u.mm,), + np.square(np.linspace(-1, 1, num=11)), + None, + None, + None, + np.square(np.linspace(-1, 1, num=11)), + ), ( (np.linspace(-1, 1, num=11),), (np.linspace(-1, 1, num=11),), diff --git a/regridding/_util.py b/regridding/_util.py index 3cf7004..8c1c62d 100644 --- a/regridding/_util.py +++ b/regridding/_util.py @@ -32,6 +32,15 @@ def _normalize_input_output_coordinates( if isinstance(coordinates_output, np.ndarray): coordinates_output = (coordinates_output,) + coordinates_input = list(coordinates_input) + for i in range(len(coordinates_input)): + coord_input = coordinates_input[i] + coord_output = coordinates_output[i] + if hasattr(coord_output, "unit"): + coord_input = coord_input << coord_output.unit + coordinates_input[i] = coord_input + coordinates_input = tuple(coordinates_input) + shape_coordinates_input = np.broadcast(*coordinates_input).shape shape_coordinates_output = np.broadcast(*coordinates_output).shape