Skip to content

interp - Prefer broadcast over reindex when possible #10554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions asv_bench/benchmarks/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,37 @@ def setup(self, *args, **kwargs):
"var1": (("x", "y"), randn_xy),
"var2": (("x", "t"), randn_xt),
"var3": (("t",), randn_t),
"var4": (("z",), np.array(["text"])),
"var5": (("k",), np.array(["a", "b", "c"])),
},
coords={
"x": np.arange(nx),
"y": np.linspace(0, 1, ny),
"t": pd.date_range("1970-01-01", periods=nt, freq="D"),
"x_coords": ("x", np.linspace(1.1, 2.1, nx)),
"z": np.array([1]),
"k": np.linspace(0, nx, 3),
},
)

@parameterized(["method", "is_short"], (["linear", "cubic"], [True, False]))
def time_interpolation(self, method, is_short):
def time_interpolation_numeric_1d(self, method, is_short):
new_x = new_x_short if is_short else new_x_long
self.ds.interp(x=new_x, method=method).load()
self.ds.interp(x=new_x, method=method).compute()

@parameterized(["method"], (["linear", "nearest"]))
def time_interpolation_2d(self, method):
self.ds.interp(x=new_x_long, y=new_y_long, method=method).load()
def time_interpolation_numeric_2d(self, method):
self.ds.interp(x=new_x_long, y=new_y_long, method=method).compute()

@parameterized(["is_short"], ([True, False]))
def time_interpolation_string_scalar(self, is_short):
    """Time ``interp`` along ``z``, the length-1 dimension of the string variable."""
    # Choose the short or the long set of target points.
    if is_short:
        target_points = new_x_short
    else:
        target_points = new_x_long
    interpolated = self.ds.interp(z=target_points)
    interpolated.compute()

@parameterized(["is_short"], ([True, False]))
def time_interpolation_string_1d(self, is_short):
    """Time ``interp`` along ``k``, the dimension of the three-element string variable."""
    # Choose the short or the long set of target points.
    if is_short:
        target_points = new_x_short
    else:
        target_points = new_x_long
    interpolated = self.ds.interp(k=target_points)
    interpolated.compute()


class InterpolationDask(Interpolation):
Expand Down
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ Internal Changes
~~~~~~~~~~~~~~~~


Performance
~~~~~~~~~~~
- Speed up :py:meth:`Dataset.interp` for non-numeric variables whose interpolated dimensions all have length 1, by broadcasting instead of reindexing. (:issue:`10054`, :pull:`10554`)
By `Jimmy Westling <https://github.com/illviljan>`_.

.. _whats-new.2025.07.1:

v2025.07.1 (July 09, 2025)
Expand Down
23 changes: 16 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,13 +3851,22 @@ def _validate_interp_indexer(x, new_x):
var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims}
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims):
# For types that we do not understand do stepwise
# interpolation to avoid modifying the elements.
# reindex the variable instead because it supports
# booleans and objects and retains the dtype but inside
# this loop there might be some duplicate code that slows it
# down, therefore collect these signals and run it later:
reindex_vars.append(name)
if all(var.sizes[d] == 1 for d in (use_indexers.keys() & var.dims)):
# Broadcastable, can be handled quickly without reindex:
to_broadcast = (var.squeeze(),) + tuple(
dest for _, dest in use_indexers.values()
)
variables[name] = broadcast_variables(*to_broadcast)[0].copy(
deep=True
)
else:
# For dtypes that cannot be interpolated numerically (object,
# boolean, unicode), do stepwise interpolation so that the
# elements are not modified. Reindexing supports these dtypes
# and retains the dtype, but doing it inside this loop might
# repeat work and slow it down, so collect the variable names
# here and reindex them later:
reindex_vars.append(name)
elif all(d not in indexers for d in var.dims):
# For anything else we can only keep variables if they
# are not dependent on any coords that are being
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,28 @@ def test_interp1d_complex_out_of_bounds() -> None:
assert_identical(actual, expected)


@requires_scipy
def test_interp_non_numeric_scalar() -> None:
    """Check ``interp`` on a length-1 string variable.

    The single value is expected to be repeated at every new coordinate,
    and the resulting array must own its data (``.base is None``).
    """
    source = xr.Dataset(
        data_vars={"non_numeric": ("time", np.array(["a"]))},
        coords={"time": np.array([0])},
    )

    result = source.interp(time=np.linspace(0, 3, 3))

    wanted = xr.Dataset(
        data_vars={"non_numeric": ("time", np.array(["a", "a", "a"]))},
        coords={"time": np.linspace(0, 3, 3)},
    )
    xr.testing.assert_identical(result, wanted)

    # Make sure the array is a copy:
    interpolated_values = result["non_numeric"].data
    assert interpolated_values.base is None


@requires_scipy
def test_interp_non_numeric_1d() -> None:
ds = xr.Dataset(
Expand Down
Loading