diff --git a/asv_bench/benchmarks/interp.py b/asv_bench/benchmarks/interp.py index 4b6691bcc0a..ca1d0a2dd89 100644 --- a/asv_bench/benchmarks/interp.py +++ b/asv_bench/benchmarks/interp.py @@ -25,23 +25,37 @@ def setup(self, *args, **kwargs): "var1": (("x", "y"), randn_xy), "var2": (("x", "t"), randn_xt), "var3": (("t",), randn_t), + "var4": (("z",), np.array(["text"])), + "var5": (("k",), np.array(["a", "b", "c"])), }, coords={ "x": np.arange(nx), "y": np.linspace(0, 1, ny), "t": pd.date_range("1970-01-01", periods=nt, freq="D"), "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + "z": np.array([1]), + "k": np.linspace(0, nx, 3), }, ) @parameterized(["method", "is_short"], (["linear", "cubic"], [True, False])) - def time_interpolation(self, method, is_short): + def time_interpolation_numeric_1d(self, method, is_short): new_x = new_x_short if is_short else new_x_long - self.ds.interp(x=new_x, method=method).load() + self.ds.interp(x=new_x, method=method).compute() @parameterized(["method"], (["linear", "nearest"])) - def time_interpolation_2d(self, method): - self.ds.interp(x=new_x_long, y=new_y_long, method=method).load() + def time_interpolation_numeric_2d(self, method): + self.ds.interp(x=new_x_long, y=new_y_long, method=method).compute() + + @parameterized(["is_short"], ([True, False])) + def time_interpolation_string_scalar(self, is_short): + new_z = new_x_short if is_short else new_x_long + self.ds.interp(z=new_z).compute() + + @parameterized(["is_short"], ([True, False])) + def time_interpolation_string_1d(self, is_short): + new_k = new_x_short if is_short else new_x_long + self.ds.interp(k=new_k).compute() class InterpolationDask(Interpolation): diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 39c6a8924f4..bace038bb17 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,11 @@ Internal Changes ~~~~~~~~~~~~~~~~ +Performance +~~~~~~~~~~~ +- Speed up non-numeric scalars when calling :py:meth:`Dataset.interp`. (:issue:`10054`, :pull:`10554`) + By `Jimmy Westling `_. + .. _whats-new.2025.07.1: v2025.07.1 (July 09, 2025) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 26db282c3df..f79df3da7c2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3851,13 +3851,22 @@ def _validate_interp_indexer(x, new_x): var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims} variables[name] = missing.interp(var, var_indexers, method, **kwargs) elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): - # For types that we do not understand do stepwise - # interpolation to avoid modifying the elements. - # reindex the variable instead because it supports - # booleans and objects and retains the dtype but inside - # this loop there might be some duplicate code that slows it - # down, therefore collect these signals and run it later: - reindex_vars.append(name) + if all(var.sizes[d] == 1 for d in (use_indexers.keys() & var.dims)): + # Broadcastable, can be handled quickly without reindex: + to_broadcast = (var.squeeze(),) + tuple( + dest for _, dest in use_indexers.values() + ) + variables[name] = broadcast_variables(*to_broadcast)[0].copy( + deep=True + ) + else: + # For types that we do not understand do stepwise + # interpolation to avoid modifying the elements. + # reindex the variable instead because it supports + # booleans and objects and retains the dtype but inside + # this loop there might be some duplicate code that slows it + # down, therefore collect these signals and run it later: + reindex_vars.append(name) elif all(d not in indexers for d in var.dims): # For anything else we can only keep variables if they # are not dependent on any coords that are being diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 7d5a9bf3db4..dd3906cfd59 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -1065,6 +1065,28 @@ def test_interp1d_complex_out_of_bounds() -> None: assert_identical(actual, expected) +@requires_scipy +def test_interp_non_numeric_scalar() -> None: + ds = xr.Dataset( + { + "non_numeric": ("time", np.array(["a"])), + }, + coords={"time": (np.array([0]))}, + ) + actual = ds.interp(time=np.linspace(0, 3, 3)) + + expected = xr.Dataset( + { + "non_numeric": ("time", np.array(["a", "a", "a"])), + }, + coords={"time": np.linspace(0, 3, 3)}, + ) + xr.testing.assert_identical(actual, expected) + + # Make sure the array is a copy: + assert actual["non_numeric"].data.base is None + + @requires_scipy def test_interp_non_numeric_1d() -> None: ds = xr.Dataset(