From 52c4727bd19afd5ea1dc9c3ecd6943b2284c75e0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 24 Jul 2025 18:48:52 +0200 Subject: [PATCH 1/6] simplify creating norms --- xarray/core/types.py | 8 +++++-- xarray/plot/dataarray_plot.py | 40 +++++++++++++++++------------------ xarray/plot/dataset_plot.py | 22 +++++++++---------- xarray/plot/utils.py | 30 ++++++++++++++++++++++---- 4 files changed, 63 insertions(+), 37 deletions(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 736a11f5f17..20b794e98e5 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -200,7 +200,9 @@ def copy( # FYI in some cases we don't allow `None`, which this doesn't take account of. # FYI the `str` is for a size string, e.g. "16MB", supported by dask. -T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051 +T_ChunkDim: TypeAlias = ( + str | int | Literal["auto"] | tuple[int, ...] | None +) # noqa: PYI051 T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim] T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq] # We allow the tuple form of this (though arguably we could transition to named dims only) @@ -289,7 +291,9 @@ def copy( HueStyleOptions = Literal["continuous", "discrete"] | None AspectOptions = Union[Literal["auto", "equal"], float, None] ExtendOptions = Literal["neither", "both", "min", "max"] | None - +NormOptions = Literal[ + "asinh", "function", "functionlog", "linear", "log", "logit", "symlog" +] _T_co = TypeVar("_T_co", covariant=True) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 537fbd5bafb..110f69a2c84 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -866,7 +866,7 @@ def newplotfunc( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1142,7 +1142,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1183,7 +1183,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1224,7 +1224,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1438,7 +1438,7 @@ def newplotfunc( yticks: ArrayLike | None = None, xlim: tuple[float, float] | None = None, ylim: tuple[float, float] | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> Any: # All 2d plots in xarray share this function signature. @@ -1692,7 +1692,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> AxesImage: ... @@ -1732,7 +1732,7 @@ def imshow( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -1772,7 +1772,7 @@ def imshow( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -1909,7 +1909,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: ... @@ -1949,7 +1949,7 @@ def contour( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -1989,7 +1989,7 @@ def contour( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2042,7 +2042,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: ... @@ -2082,7 +2082,7 @@ def contourf( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2122,7 +2122,7 @@ def contourf( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2175,7 +2175,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> QuadMesh: ... @@ -2215,7 +2215,7 @@ def pcolormesh( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2255,7 +2255,7 @@ def pcolormesh( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2359,7 +2359,7 @@ def surface( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> Poly3DCollection: ... @@ -2399,7 +2399,7 @@ def surface( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2439,7 +2439,7 @@ def surface( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index ff508ee213c..97ac748d146 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -183,7 +183,7 @@ def newplotfunc( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, robust: bool | None = None, @@ -345,7 +345,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -382,7 +382,7 @@ def quiver( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -419,7 +419,7 @@ def quiver( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -496,7 +496,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -533,7 +533,7 @@ def streamplot( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -570,7 +570,7 @@ def streamplot( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -783,7 +783,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, @@ -824,7 +824,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, @@ -865,7 +865,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, @@ -906,7 +906,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a71613562a5..5fe5de41c6d 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -44,7 +44,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import AspectOptions, ScaleOptions + from xarray.core.types import AspectOptions, ScaleOptions, NormOptions try: import matplotlib.pyplot as plt @@ -58,6 +58,26 @@ _LINEWIDTH_RANGE = (1.5, 1.5, 6.0) +def _make_norm_from_string( + norm: NormOptions, +) -> type[Normalize]: + """ + Get norm from string. + + Examples + -------- + >>> _make_norm_from_string("log") + + + """ + from matplotlib.colors import make_norm_from_scale, Normalize + from matplotlib.scale import scale_factory + + scale = type(scale_factory(norm, None)) # type: ignore [arg-type] # mpl issue, use of ax is discouraged + + return make_norm_from_scale(scale, Normalize) + + def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -264,6 +284,8 @@ def _determine_cmap_params( # now check norm and harmonize with vmin, vmax if norm is not None: + norm = _make_norm_from_string(norm)() if isinstance(norm, str) else norm + if norm.vmin is None: norm.vmin = vmin else: @@ -278,9 +300,9 @@ def _determine_cmap_params( raise ValueError("Cannot supply vmax and a norm with a different vmax.") vmax = norm.vmax - # if BoundaryNorm, then set levels - if isinstance(norm, mpl.colors.BoundaryNorm): - levels = norm.boundaries + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries # Choose default colormaps if not provided if cmap is None: From af31fcb07f311aaecb9b8e58df6f78a8e0193a9f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:51:04 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/types.py | 4 +--- xarray/plot/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 20b794e98e5..2a3918c838b 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -200,9 +200,7 @@ def copy( # FYI in some cases we don't allow `None`, which this doesn't take account of. # FYI the `str` is for a size string, e.g. "16MB", supported by dask. -T_ChunkDim: TypeAlias = ( - str | int | Literal["auto"] | tuple[int, ...] | None -) # noqa: PYI051 +T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim] T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq] # We allow the tuple form of this (though arguably we could transition to named dims only) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 5fe5de41c6d..753d85ab280 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -44,7 +44,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import AspectOptions, ScaleOptions, NormOptions + from xarray.core.types import AspectOptions, NormOptions, ScaleOptions try: import matplotlib.pyplot as plt @@ -70,7 +70,7 @@ def _make_norm_from_string( """ - from matplotlib.colors import make_norm_from_scale, Normalize + from matplotlib.colors import Normalize, make_norm_from_scale from matplotlib.scale import scale_factory scale = type(scale_factory(norm, None)) # type: ignore [arg-type] # mpl issue, use of ax is discouraged From 2ad69cf92ed7cd4b311be001ca3d49317ca5aee6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 26 Jul 2025 20:29:49 +0200 Subject: [PATCH 3/6] import also --- xarray/plot/dataarray_plot.py | 1 + xarray/plot/dataset_plot.py | 1 + 2 files changed, 2 insertions(+) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 110f69a2c84..5e8878dc952 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -50,6 +50,7 @@ AspectOptions, ExtendOptions, HueStyleOptions, + NormOptions, ScaleOptions, T_DataArray, ) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 97ac748d146..21641312022 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -30,6 +30,7 @@ AspectOptions, ExtendOptions, HueStyleOptions, + NormOptions, ScaleOptions, ) from xarray.plot.facetgrid import FacetGrid From 8bdf1f4a97720137cca0ab14b5cc44a6b8918828 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 26 Jul 2025 21:13:15 +0200 Subject: [PATCH 4/6] Update xarray/core/types.py --- xarray/core/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 2a3918c838b..a1226c643b7 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -200,7 +200,7 @@ def copy( # FYI in some cases we don't allow `None`, which this doesn't take account of. # FYI the `str` is for a size string, e.g. "16MB", supported by dask. -T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None +T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051 T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim] T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq] # We allow the tuple form of this (though arguably we could transition to named dims only) From 26753b7af8d7d2fbcac826ea733e32bc3754309f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 26 Jul 2025 21:17:04 +0200 Subject: [PATCH 5/6] Update xarray/plot/utils.py --- xarray/plot/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 753d85ab280..2cc457d7c09 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -300,9 +300,9 @@ def _determine_cmap_params( raise ValueError("Cannot supply vmax and a norm with a different vmax.") vmax = norm.vmax - # if BoundaryNorm, then set levels - if isinstance(norm, mpl.colors.BoundaryNorm): - levels = norm.boundaries + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries # Choose default colormaps if not provided if cmap is None: From e0833b3a378c1d9506528c0ed02e83e0778b61ae Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 26 Jul 2025 21:17:47 +0200 Subject: [PATCH 6/6] Update xarray/plot/utils.py --- xarray/plot/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 2cc457d7c09..cef035dbfb6 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -301,7 +301,7 @@ def _determine_cmap_params( vmax = norm.vmax # if BoundaryNorm, then set levels - if isinstance(norm, mpl.colors.BoundaryNorm): + if isinstance(norm, mpl.colors.BoundaryNorm): levels = norm.boundaries # Choose default colormaps if not provided