diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 10bf1466156..7af83ae87a5 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -8,123 +8,77 @@
 import sys
 import warnings
 from collections import defaultdict
-from collections.abc import (
-    Collection,
-    Hashable,
-    Iterable,
-    Iterator,
-    Mapping,
-    MutableMapping,
-    Sequence,
-)
+from collections.abc import (Collection, Hashable, Iterable, Iterator, Mapping,
+                             MutableMapping, Sequence)
 from html import escape
 from numbers import Number
 from operator import methodcaller
 from os import PathLike
-from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload
+from typing import (IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast,
+                    overload)
 
 import numpy as np
-
-# remove once numpy 2.0 is the oldest supported version
-try:
-    from numpy.exceptions import RankWarning  # type: ignore[attr-defined,unused-ignore]
-except ImportError:
-    from numpy import RankWarning
-
 import pandas as pd
+from line_profiler import profile as codeflash_line_profile
+codeflash_line_profile.enable(output_prefix='/tmp/codeflash_hyrdbv3n/baseline_lprof')
 
 from xarray.coding.calendar_ops import convert_calendar, interp_calendar
-from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
-from xarray.core import (
-    alignment,
-    duck_array_ops,
-    formatting,
-    formatting_html,
-    ops,
-    utils,
-)
+from xarray.coding.cftimeindex import (CFTimeIndex,
+                                       _parse_array_of_cftime_strings)
+from xarray.core import alignment
 from xarray.core import dtypes as xrdtypes
+from xarray.core import duck_array_ops, formatting, formatting_html, ops, utils
 from xarray.core._aggregations import DatasetAggregations
-from xarray.core.alignment import (
-    _broadcast_helper,
-    _get_broadcast_dims_map_common_coords,
-    align,
-)
+from xarray.core.alignment import (_broadcast_helper,
+                                   _get_broadcast_dims_map_common_coords,
+                                   align)
 from xarray.core.arithmetic import DatasetArithmetic
-from xarray.core.common import (
-    DataWithCoords,
-    _contains_datetime_like_objects,
-    get_chunksizes,
-)
+from xarray.core.common import (DataWithCoords,
+                                _contains_datetime_like_objects,
+                                get_chunksizes)
 from xarray.core.computation import unify_chunks
-from xarray.core.coordinates import (
-    Coordinates,
-    DatasetCoordinates,
-    assert_coordinate_consistent,
-    create_coords_with_default_indexes,
-)
+from xarray.core.coordinates import (Coordinates, DatasetCoordinates,
+                                     assert_coordinate_consistent,
+                                     create_coords_with_default_indexes)
 from xarray.core.duck_array_ops import datetime_to_numeric
-from xarray.core.indexes import (
-    Index,
-    Indexes,
-    PandasIndex,
-    PandasMultiIndex,
-    assert_no_index_corrupted,
-    create_default_index_implicit,
-    filter_indexes_from_coords,
-    isel_indexes,
-    remove_unused_levels_categories,
-    roll_indexes,
-)
+from xarray.core.indexes import (Index, Indexes, PandasIndex, PandasMultiIndex,
+                                 assert_no_index_corrupted,
+                                 create_default_index_implicit,
+                                 filter_indexes_from_coords, isel_indexes,
+                                 remove_unused_levels_categories, roll_indexes)
 from xarray.core.indexing import is_fancy_indexer, map_index_queries
-from xarray.core.merge import (
-    dataset_merge_method,
-    dataset_update_method,
-    merge_coordinates_without_align,
-    merge_core,
-)
+from xarray.core.merge import (dataset_merge_method, dataset_update_method,
+                               merge_coordinates_without_align, merge_core)
 from xarray.core.missing import get_clean_interp_index
 from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.types import (
-    QuantileMethods,
-    Self,
-    T_ChunkDim,
-    T_Chunks,
-    T_DataArray,
-    T_DataArrayOrSet,
-    T_Dataset,
-    ZarrWriteModes,
-)
-from xarray.core.utils import (
-    Default,
-    Frozen,
-    FrozenMappingWarningOnValuesAccess,
-    HybridMappingProxy,
-    OrderedSet,
-    _default,
-    decode_numpy_dict_values,
-    drop_dims_from_indexers,
-    either_dict_or_kwargs,
-    emit_user_level_warning,
-    infix_dims,
-    is_dict_like,
-    is_duck_array,
-    is_duck_dask_array,
-    is_scalar,
-    maybe_wrap_array,
-)
-from xarray.core.variable import (
-    IndexVariable,
-    Variable,
-    as_variable,
-    broadcast_variables,
-    calculate_dimensions,
-)
-from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
+from xarray.core.types import (QuantileMethods, Self, T_ChunkDim, T_Chunks,
+                               T_DataArray, T_DataArrayOrSet, T_Dataset,
+                               ZarrWriteModes)
+from xarray.core.utils import (Default, Frozen,
+                               FrozenMappingWarningOnValuesAccess,
+                               HybridMappingProxy, OrderedSet, _default,
+                               decode_numpy_dict_values,
+                               drop_dims_from_indexers, either_dict_or_kwargs,
+                               emit_user_level_warning, infix_dims,
+                               is_dict_like, is_duck_array, is_duck_dask_array,
+                               is_scalar, maybe_wrap_array)
+from xarray.core.variable import (IndexVariable, Variable, as_variable,
+                                  broadcast_variables, calculate_dimensions)
+from xarray.namedarray.parallelcompat import (get_chunked_array_type,
+                                              guess_chunkmanager)
 from xarray.namedarray.pycompat import array_type, is_chunked_array
 from xarray.plot.accessor import DatasetPlotAccessor
 from xarray.util.deprecation_helpers import _deprecate_positional_args
 
+# remove once numpy 2.0 is the oldest supported version
+try:
+    from numpy.exceptions import \
+        RankWarning  # type: ignore[attr-defined,unused-ignore]
+except ImportError:
+    from numpy import RankWarning
+
+
+
 if TYPE_CHECKING:
     from dask.dataframe import DataFrame as DaskDataFrame
     from dask.delayed import Delayed
@@ -134,31 +88,19 @@
     from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
     from xarray.core.dataarray import DataArray
     from xarray.core.groupby import DatasetGroupBy
-    from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult
+    from xarray.core.merge import (CoercibleMapping, CoercibleValue,
+                                   _MergeResult)
    from xarray.core.resample import DatasetResample
     from xarray.core.rolling import DatasetCoarsen, DatasetRolling
-    from xarray.core.types import (
-        CFCalendar,
-        CoarsenBoundaryOptions,
-        CombineAttrsOptions,
-        CompatOptions,
-        DataVars,
-        DatetimeLike,
-        DatetimeUnitOptions,
-        Dims,
-        DsCompatible,
-        ErrorOptions,
-        ErrorOptionsWithWarn,
-        InterpOptions,
-        JoinOptions,
-        PadModeOptions,
-        PadReflectOptions,
-        QueryEngineOptions,
-        QueryParserOptions,
-        ReindexMethodOptions,
-        SideOptions,
-        T_Xarray,
-    )
+    from xarray.core.types import (CFCalendar, CoarsenBoundaryOptions,
+                                   CombineAttrsOptions, CompatOptions,
+                                   DataVars, DatetimeLike, DatetimeUnitOptions,
+                                   Dims, DsCompatible, ErrorOptions,
+                                   ErrorOptionsWithWarn, InterpOptions,
+                                   JoinOptions, PadModeOptions,
+                                   PadReflectOptions, QueryEngineOptions,
+                                   QueryParserOptions, ReindexMethodOptions,
+                                   SideOptions, T_Xarray)
     from xarray.core.weighted import DatasetWeighted
     from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
 
@@ -410,6 +352,7 @@ def _initialize_feasible(lb, ub):
     return param_defaults, bounds_defaults
 
 
+@codeflash_line_profile
 def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult:
     """Used in Dataset.__init__."""
     if isinstance(coords, Coordinates):
@@ -2934,7 +2877,9 @@ def isel(
         """
         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
-        if any(is_fancy_indexer(idx) for idx in indexers.values()):
+        # Bind the values view to a local name before scanning for fancy indexers
+        indexers_values = indexers.values()
+        if any(is_fancy_indexer(idx) for idx in indexers_values):
             return self._isel_fancy(indexers, drop=drop, missing_dims=missing_dims)
 
         # Much faster algorithm for when all indexers are ints, slices, one-dimensional
@@ -2947,15 +2892,23 @@ def isel(
 
         indexes, index_variables = isel_indexes(self.xindexes, indexers)
 
+        # Bind coord_names to a local name for the membership check inside the loop
+        coord_names_set = coord_names
+
         for name, var in self._variables.items():
             # preserve variable order
             if name in index_variables:
                 var = index_variables[name]
             else:
-                var_indexers = {k: v for k, v in indexers.items() if k in var.dims}
+                # Avoid building dict when unnecessary
+                var_indexers = None
+                if var.dims:
+                    shared = [k for k in indexers if k in var.dims]
+                    if shared:
+                        var_indexers = {k: indexers[k] for k in shared}
                 if var_indexers:
                     var = var.isel(var_indexers)
-                    if drop and var.ndim == 0 and name in coord_names:
+                    if drop and var.ndim == 0 and name in coord_names_set:
                         coord_names.remove(name)
                         continue
             variables[name] = var
@@ -10178,12 +10131,9 @@ def groupby(
         Dataset.resample
         DataArray.resample
         """
-        from xarray.core.groupby import (
-            DatasetGroupBy,
-            ResolvedGrouper,
-            UniqueGrouper,
-            _validate_groupby_squeeze,
-        )
+        from xarray.core.groupby import (DatasetGroupBy, ResolvedGrouper,
+                                         UniqueGrouper,
+                                         _validate_groupby_squeeze)
 
         _validate_groupby_squeeze(squeeze)
         rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
@@ -10263,12 +10213,9 @@ def groupby_bins(
         ----------
         .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
         """
-        from xarray.core.groupby import (
-            BinGrouper,
-            DatasetGroupBy,
-            ResolvedGrouper,
-            _validate_groupby_squeeze,
-        )
+        from xarray.core.groupby import (BinGrouper, DatasetGroupBy,
+                                         ResolvedGrouper,
+                                         _validate_groupby_squeeze)
 
         _validate_groupby_squeeze(squeeze)
         grouper = BinGrouper(
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index 804e1cfd795..371bf783f18 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -327,47 +327,58 @@ def _infer_xy_labels_3d(
     assert rgb is None or rgb != y
     # Start by detecting and reporting invalid combinations of arguments
     assert darray.ndim == 3
-    not_none = [a for a in (x, y, rgb) if a is not None]
-    if len(set(not_none)) < len(not_none):
-        raise ValueError(
-            "Dimension names must be None or unique strings, but imshow was "
-            f"passed x={x!r}, y={y!r}, and rgb={rgb!r}."
-        )
-    for label in not_none:
-        if label not in darray.dims:
-            raise ValueError(f"{label!r} is not a dimension")
-
-    # Then calculate rgb dimension if certain and check validity
-    could_be_color = [
-        label
-        for label in darray.dims
-        if darray[label].size in (3, 4) and label not in (x, y)
-    ]
-    if rgb is None and not could_be_color:
-        raise ValueError(
-            "A 3-dimensional array was passed to imshow(), but there is no "
-            "dimension that could be color. At least one dimension must be "
-            "of size 3 (RGB) or 4 (RGBA), and not given as x or y."
-        )
-    if rgb is None and len(could_be_color) == 1:
-        rgb = could_be_color[0]
-    if rgb is not None and darray[rgb].size not in (3, 4):
-        raise ValueError(
-            f"Cannot interpret dim {rgb!r} of size {darray[rgb].size} as RGB or RGBA."
-        )
+    dims = darray.dims
+
+    # Validate uniqueness and existence of dimensions
+    vals = (x, y, rgb)
+    seen = set()
+    for label in vals:
+        if label is not None:
+            if label in seen:
+                raise ValueError(
+                    "Dimension names must be None or unique strings, but imshow was "
+                    f"passed x={x!r}, y={y!r}, and rgb={rgb!r}."
+                )
+            seen.add(label)
+            if label not in dims:
+                raise ValueError(f"{label!r} is not a dimension")
+
+    # Find eligible color dimension(s)
+    could_be_color = []
+    for label in dims:
+        if label not in (x, y):
+            size = darray[label].size
+            if size == 3 or size == 4:
+                could_be_color.append(label)
 
     # If rgb dimension is still unknown, there must be two or three dimensions
     # in could_be_color. We therefore warn, and use a heuristic to break ties.
     if rgb is None:
-        assert len(could_be_color) in (2, 3)
-        rgb = could_be_color[-1]
-        warnings.warn(
-            "Several dimensions of this array could be colors. Xarray "
-            f"will use the last possible dimension ({rgb!r}) to match "
-            "matplotlib.pyplot.imshow. You can pass names of x, y, "
-            "and/or rgb dimensions to override this guess."
-        )
-    assert rgb is not None
+        if not could_be_color:
+            raise ValueError(
+                "A 3-dimensional array was passed to imshow(), but there is no "
+                "dimension that could be color. At least one dimension must be "
+                "of size 3 (RGB) or 4 (RGBA), and not given as x or y."
+            )
+        if len(could_be_color) == 1:
+            rgb = could_be_color[0]
+        else:
+            # There must be 2 or 3 possible color dims; warn and pick last (matplotlib default)
+            rgb = could_be_color[-1]
+            warnings.warn(
+                "Several dimensions of this array could be colors. Xarray "
+                f"will use the last possible dimension ({rgb!r}) to match "
+                "matplotlib.pyplot.imshow. You can pass names of x, y, "
+                "and/or rgb dimensions to override this guess."
+            )
+    else:
+        rgb_size = darray[rgb].size
+        if rgb_size not in (3, 4):
+            raise ValueError(
+                f"Cannot interpret dim {rgb!r} of size {rgb_size} as RGB or RGBA."
+            )
 
     # Finally, we pick out the red slice and delegate to the 2D version:
     return _infer_xy_labels(darray.isel({rgb: 0}), x, y)
@@ -385,6 +396,8 @@ def _infer_xy_labels(
 
     darray must be a 2 dimensional data array, or 3d for imshow only.
     """
+    dims = darray.dims
+
     if (x is not None) and (x == y):
         raise ValueError("x and y cannot be equal.")
 
@@ -394,18 +407,25 @@
     if x is None and y is None:
         if darray.ndim != 2:
             raise ValueError("DataArray must be 2d")
-        y, x = darray.dims
+        y, x = dims
     elif x is None:
         _assert_valid_xy(darray, y, "y")
-        x = darray.dims[0] if y == darray.dims[1] else darray.dims[1]
+        # Pick first non-y dim as x
+        if y == dims[1]:
+            x = dims[0]
+        else:
+            x = dims[1]
     elif y is None:
         _assert_valid_xy(darray, x, "x")
-        y = darray.dims[0] if x == darray.dims[1] else darray.dims[1]
+        if x == dims[1]:
+            y = dims[0]
+        else:
+            y = dims[1]
     else:
         _assert_valid_xy(darray, x, "x")
         _assert_valid_xy(darray, y, "y")
-
-        if darray._indexes.get(x, 1) is darray._indexes.get(y, 2):
+        # Distinct fallback sentinels keep dims that have no index from comparing equal
+        if darray._indexes.get(x, 1) is darray._indexes.get(y, 2):
             if isinstance(darray._indexes[x], PandasMultiIndex):
                 raise ValueError("x and y cannot be levels of the same MultiIndex")
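
The guard at the end of `_infer_xy_labels` relies on distinct fallback arguments to `.get()` so that two dims that have no index never compare identical (with a shared `None` default, the `is` check would succeed for two missing entries and the code would then fail looking up `darray._indexes[x]`). A minimal sketch of that behavior, using plain dicts standing in for `DataArray._indexes` (illustrative only, not xarray API):

```python
# Plain-dict sketch of why the .get() fallbacks are distinct sentinels.
indexes = {}  # neither "x" nor "y" has an index, e.g. dims without coordinates

# Distinct sentinels: two missing entries can never be the same object,
# so the MultiIndex branch is skipped as intended.
assert not (indexes.get("x", 1) is indexes.get("y", 2))

# Shared default: both lookups return None, `None is None` is True,
# and the follow-up lookup indexes["x"] would raise KeyError.
assert indexes.get("x") is indexes.get("y")
```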
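
The reworked `var_indexers` construction in `Dataset.isel` only builds the selection dict when some indexer actually touches the variable's dims. A rough standalone sketch (the `indexers` and `var_dims` inputs here are made-up samples, not part of the patch) of why this matches the original dict comprehension:

```python
# Hypothetical inputs mimicking Dataset.isel internals.
indexers = {"x": slice(0, 2), "time": 0}
var_dims = ("time", "y")

# Original form: always build the dict.
old = {k: v for k, v in indexers.items() if k in var_dims}

# Reworked form: skip the dict build when nothing applies.
new = None
if var_dims:
    shared = [k for k in indexers if k in var_dims]
    if shared:
        new = {k: indexers[k] for k in shared}

# Both forms select the same subset, here {"time": 0}.
assert old == (new or {})
```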