Skip to content

Allow creating norms from strings #10566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ def copy(
HueStyleOptions = Literal["continuous", "discrete"] | None
AspectOptions = Union[Literal["auto", "equal"], float, None]
ExtendOptions = Literal["neither", "both", "min", "max"] | None

NormOptions = Literal[
"asinh", "function", "functionlog", "linear", "log", "logit", "symlog"
]

_T_co = TypeVar("_T_co", covariant=True)

Expand Down
41 changes: 21 additions & 20 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
AspectOptions,
ExtendOptions,
HueStyleOptions,
NormOptions,
ScaleOptions,
T_DataArray,
)
Expand Down Expand Up @@ -866,7 +867,7 @@ def newplotfunc(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1142,7 +1143,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1183,7 +1184,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1224,7 +1225,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1438,7 +1439,7 @@ def newplotfunc(
yticks: ArrayLike | None = None,
xlim: tuple[float, float] | None = None,
ylim: tuple[float, float] | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> Any:
# All 2d plots in xarray share this function signature.
Expand Down Expand Up @@ -1692,7 +1693,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> AxesImage: ...

Expand Down Expand Up @@ -1732,7 +1733,7 @@ def imshow(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -1772,7 +1773,7 @@ def imshow(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -1909,7 +1910,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> QuadContourSet: ...

Expand Down Expand Up @@ -1949,7 +1950,7 @@ def contour(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -1989,7 +1990,7 @@ def contour(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2042,7 +2043,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> QuadContourSet: ...

Expand Down Expand Up @@ -2082,7 +2083,7 @@ def contourf(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2122,7 +2123,7 @@ def contourf(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2175,7 +2176,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> QuadMesh: ...

Expand Down Expand Up @@ -2215,7 +2216,7 @@ def pcolormesh(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2255,7 +2256,7 @@ def pcolormesh(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2359,7 +2360,7 @@ def surface(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> Poly3DCollection: ...

Expand Down Expand Up @@ -2399,7 +2400,7 @@ def surface(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2439,7 +2440,7 @@ def surface(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down
23 changes: 12 additions & 11 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AspectOptions,
ExtendOptions,
HueStyleOptions,
NormOptions,
ScaleOptions,
)
from xarray.plot.facetgrid import FacetGrid
Expand Down Expand Up @@ -183,7 +184,7 @@ def newplotfunc(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
robust: bool | None = None,
Expand Down Expand Up @@ -345,7 +346,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -382,7 +383,7 @@ def quiver(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -419,7 +420,7 @@ def quiver(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -496,7 +497,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -533,7 +534,7 @@ def streamplot(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -570,7 +571,7 @@ def streamplot(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -783,7 +784,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -824,7 +825,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -865,7 +866,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -906,7 +907,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down
24 changes: 23 additions & 1 deletion xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import AspectOptions, ScaleOptions
from xarray.core.types import AspectOptions, NormOptions, ScaleOptions

try:
import matplotlib.pyplot as plt
Expand All @@ -58,6 +58,26 @@
_LINEWIDTH_RANGE = (1.5, 1.5, 6.0)


def _make_norm_from_string(
norm: NormOptions,
) -> type[Normalize]:
"""
Get norm from string.

Examples
--------
>>> _make_norm_from_string("log")
<class 'matplotlib.colors.LogScaleNorm'>

"""
from matplotlib.colors import Normalize, make_norm_from_scale
from matplotlib.scale import scale_factory

scale = type(scale_factory(norm, None)) # type: ignore [arg-type] # mpl issue, use of ax is discouraged

return make_norm_from_scale(scale, Normalize)


def _determine_extend(calc_data, vmin, vmax):
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
Expand Down Expand Up @@ -264,6 +284,8 @@ def _determine_cmap_params(

# now check norm and harmonize with vmin, vmax
if norm is not None:
norm = _make_norm_from_string(norm)() if isinstance(norm, str) else norm

if norm.vmin is None:
norm.vmin = vmin
else:
Expand Down
Loading