xarray/core/missing.py: 58 changes (41 additions, 17 deletions)
@@ -552,14 +552,23 @@ def _localize(var, indexes_coords):
     Only consider a subspace that is needed for the interpolation
     """
     indexes = {}
-    for dim, [x, new_x] in indexes_coords.items():
-        minval = np.nanmin(new_x.values)
-        maxval = np.nanmax(new_x.values)
+    for dim, pair in indexes_coords.items():
+        x, new_x = pair
+        # Cache new_x.values so it is only computed once per dimension
+        new_x_values = new_x.values
+        minval = np.nanmin(new_x_values)
+        maxval = np.nanmax(new_x_values)
         index = x.to_index()
-        imin = index.get_indexer([minval], method="nearest").item()
-        imax = index.get_indexer([maxval], method="nearest").item()
-        indexes[dim] = slice(max(imin - 2, 0), imax + 2)
-        indexes_coords[dim] = (x[indexes[dim]], new_x)
+        # Look up minval and maxval with a single batched get_indexer call
+        minmax = np.array([minval, maxval])
+        nearest_idx = index.get_indexer(minmax, method="nearest")
+        imin, imax = nearest_idx[0].item(), nearest_idx[1].item()
+        start = max(imin - 2, 0)
+        stop = imax + 2
+        slicer = slice(start, stop)
+        indexes[dim] = slicer
+        # x[slicer] stays cheap for a plain slice; store the localized pair in one go
+        indexes_coords[dim] = (x[slicer], new_x)
     return var.isel(**indexes), indexes_coords
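
A minimal standalone sketch of the batched nearest-index lookup in this hunk. The index values, target points, and the names index, new_vals, and subset below are invented for illustration; only get_indexer(..., method="nearest") and the two-row padding mirror the code above.

import numpy as np
import pandas as pd

index = pd.Index(np.linspace(0.0, 10.0, 101))   # stand-in for x.to_index()
new_vals = np.array([2.34, 7.81, 5.5])          # stand-in for new_x.values

minval, maxval = np.nanmin(new_vals), np.nanmax(new_vals)
# one get_indexer call for both endpoints instead of two separate calls
nearest = index.get_indexer(np.array([minval, maxval]), method="nearest")
imin, imax = nearest[0].item(), nearest[1].item()
subset = slice(max(imin - 2, 0), imax + 2)      # pad by two rows on each side
print(subset)                                   # slice(21, 80, None) for this data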


@@ -766,14 +775,22 @@ def _interpnd(var, x, new_x, func, kwargs):
     if len(x) == 1:
         return _interp1d(var, x, new_x, func, kwargs)
 
-    # move the interpolation axes to the start position
-    var = var.transpose(range(-len(x), var.ndim - len(x)))
+    # Cache the number of interpolation dimensions
+    n_interp = len(x)
+    # move the interpolation axes to the start position,
+    # building the axis order as an explicit tuple up front
+    old_order = tuple(range(-n_interp, var.ndim - n_interp))
+    var = var.transpose(*old_order)
+    # func expects the target points stacked into a single (npoints, n_interp) array
     # stack new_x to 1 vector, with reshape
     xi = np.stack([x1.values.ravel() for x1 in new_x], axis=-1)
     rslt = func(x, var, xi, **kwargs)
-    # move back the interpolation axes to the last position
-    rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
-    return reshape(rslt, rslt.shape[:-1] + new_x[0].shape)
+    # move the interpolation axes back to the last position
+    rslt_dims = rslt.ndim
+    rslt = rslt.transpose(*tuple(range(-rslt_dims + 1, 1)))
+    # reshape the stacked sample axis back to the shape of new_x
+    out_shape = rslt.shape[:-1] + new_x[0].shape
+    return reshape(rslt, out_shape)
 
 
 def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
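
A NumPy-only sketch of the two transposes in _interpnd above. The array shapes, the n_interp value, and the names data, front, rslt, and back are invented; the real code operates on the interpolation function's output rather than these dummy arrays.

import numpy as np

data = np.empty((4, 5, 6, 7))   # stand-in for the raw array behind var
n_interp = 2                    # pretend the last two axes are interpolated

# move the interpolation axes to the front, passing an explicit axis tuple
order = tuple(range(-n_interp, data.ndim - n_interp))   # (-2, -1, 0, 1)
front = data.transpose(order)
print(front.shape)              # (6, 7, 4, 5)

# after interpolation the leading axis holds the stacked target points;
# the second transpose moves that single axis back to the last position
rslt = np.empty((10, 4, 5))     # e.g. 10 flattened target points
back = rslt.transpose(tuple(range(-rslt.ndim + 1, 1)))  # axes (1, 2, 0)
print(back.shape)               # (4, 5, 10)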
@@ -785,16 +802,23 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
     n_x = len(coords) // 2
     nconst = len(var.shape) - n_x
 
-    # _interpnd expect coords to be Variables
-    x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
+    # _interpnd expects coords to be Variables; precompute the dim names once
+    const_dims = [f"dim_{nconst + dim}" for dim in range(n_x)]
+    x = [Variable([const_dims[dim]], _x) for dim, _x in enumerate(coords[:n_x])]
+
+    # Cache len(var.shape) so it is not recomputed for every new_x
+    total_dim = len(var.shape)
     new_x = [
-        Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x)
+        Variable([f"dim_{total_dim + dim}" for dim in range(len(_x.shape))], _x)
         for _x in coords[n_x:]
     ]
 
     if localize:
-        # _localize expect var to be a Variable
-        var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var)
+        # _localize expects var to be a Variable; build the dim name list up front
+        var_dims = [f"dim_{dim}" for dim in range(len(var.shape))]
+        var = Variable(var_dims, var)
 
+        # Map each original coordinate's dim to its (x, new_x) pair in one comprehension
+
         indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)}
 
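
For context, a small sketch of the dim-name bookkeeping this hunk keeps intact, using made-up inputs: a 2-D var with one constant dimension and a single 1-D destination coordinate. The names data and coords and the printed dims are illustrative only.

import numpy as np
from xarray import Variable

data = np.empty((3, 8))                      # stand-in for var: 1 constant dim + 1 interp dim
coords = (np.linspace(0.0, 1.0, 8),          # original coordinate of the interpolated dim
          np.array([0.2, 0.4, 0.9]))         # destination coordinate

n_x = len(coords) // 2                       # 1 original / 1 destination coordinate
nconst = data.ndim - n_x                     # 1 constant dimension

const_dims = [f"dim_{nconst + dim}" for dim in range(n_x)]
x = [Variable([const_dims[dim]], _x) for dim, _x in enumerate(coords[:n_x])]

total_dim = data.ndim                        # destination dims are numbered after var's dims
new_x = [
    Variable([f"dim_{total_dim + dim}" for dim in range(_x.ndim)], _x)
    for _x in coords[n_x:]
]
print(x[0].dims, new_x[0].dims)              # ('dim_1',) ('dim_2',)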