diff --git a/xarray/core/types.py b/xarray/core/types.py index 736a11f5f17..a1226c643b7 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -289,7 +289,9 @@ def copy( HueStyleOptions = Literal["continuous", "discrete"] | None AspectOptions = Union[Literal["auto", "equal"], float, None] ExtendOptions = Literal["neither", "both", "min", "max"] | None - +NormOptions = Literal[ + "asinh", "function", "functionlog", "linear", "log", "logit", "symlog" +] _T_co = TypeVar("_T_co", covariant=True) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 537fbd5bafb..5e8878dc952 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -50,6 +50,7 @@ AspectOptions, ExtendOptions, HueStyleOptions, + NormOptions, ScaleOptions, T_DataArray, ) @@ -866,7 +867,7 @@ def newplotfunc( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1142,7 +1143,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1183,7 +1184,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1224,7 +1225,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, @@ -1438,7 +1439,7 @@ def newplotfunc( yticks: ArrayLike | None = None, xlim: tuple[float, float] | None = None, ylim: tuple[float, float] | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> Any: # All 2d plots in xarray share this function signature. @@ -1692,7 +1693,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> AxesImage: ... @@ -1732,7 +1733,7 @@ def imshow( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -1772,7 +1773,7 @@ def imshow( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -1909,7 +1910,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: ... @@ -1949,7 +1950,7 @@ def contour( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -1989,7 +1990,7 @@ def contour( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2042,7 +2043,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: ... @@ -2082,7 +2083,7 @@ def contourf( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2122,7 +2123,7 @@ def contourf( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2175,7 +2176,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> QuadMesh: ... @@ -2215,7 +2216,7 @@ def pcolormesh( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2255,7 +2256,7 @@ def pcolormesh( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2359,7 +2360,7 @@ def surface( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> Poly3DCollection: ... @@ -2399,7 +2400,7 @@ def surface( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... @@ -2439,7 +2440,7 @@ def surface( yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, **kwargs: Any, ) -> FacetGrid[T_DataArray]: ... diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index ff508ee213c..21641312022 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -30,6 +30,7 @@ AspectOptions, ExtendOptions, HueStyleOptions, + NormOptions, ScaleOptions, ) from xarray.plot.facetgrid import FacetGrid @@ -183,7 +184,7 @@ def newplotfunc( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, robust: bool | None = None, @@ -345,7 +346,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -382,7 +383,7 @@ def quiver( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -419,7 +420,7 @@ def quiver( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -496,7 +497,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -533,7 +534,7 @@ def streamplot( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -570,7 +571,7 @@ def streamplot( cbar_ax: Axes | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, infer_intervals: bool | None = None, center: float | None = None, levels: ArrayLike | None = None, @@ -783,7 +784,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, @@ -824,7 +825,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, @@ -865,7 +866,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, @@ -906,7 +907,7 @@ def scatter( cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, - norm: Normalize | None = None, + norm: NormOptions | Normalize | None = None, extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a71613562a5..cef035dbfb6 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -44,7 +44,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import AspectOptions, ScaleOptions + from xarray.core.types import AspectOptions, NormOptions, ScaleOptions try: import matplotlib.pyplot as plt @@ -58,6 +58,26 @@ _LINEWIDTH_RANGE = (1.5, 1.5, 6.0) +def _make_norm_from_string( + norm: NormOptions, +) -> type[Normalize]: + """ + Get norm from string. + + Examples + -------- + >>> _make_norm_from_string("log") + + + """ + from matplotlib.colors import Normalize, make_norm_from_scale + from matplotlib.scale import scale_factory + + scale = type(scale_factory(norm, None)) # type: ignore [arg-type] # mpl issue, use of ax is discouraged + + return make_norm_from_scale(scale, Normalize) + + def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -264,6 +284,8 @@ def _determine_cmap_params( # now check norm and harmonize with vmin, vmax if norm is not None: + norm = _make_norm_from_string(norm)() if isinstance(norm, str) else norm + if norm.vmin is None: norm.vmin = vmin else: