From 692cda5d87a46259e6870dadb3e7e5f3bd1cbef5 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 26 Mar 2026 21:59:56 +0100 Subject: [PATCH 1/7] support for col_wrap='auto' --- xarray/plot/accessor.py | 54 +++++++++++++++---------------- xarray/plot/dataarray_plot.py | 60 +++++++++++++++++++---------------- xarray/plot/dataset_plot.py | 30 ++++++++++-------- xarray/plot/facetgrid.py | 28 ++++++++-------- 4 files changed, 89 insertions(+), 83 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index 9db4ae4e3f7..2b4c28a9027 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -192,7 +192,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -232,7 +232,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -272,7 +272,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -311,7 +311,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -350,7 +350,7 @@ def imshow( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -389,7 +389,7 @@ def imshow( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -432,7 +432,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -471,7 +471,7 @@ def contour( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -510,7 +510,7 @@ def contour( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -553,7 +553,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -592,7 +592,7 @@ def contourf( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -631,7 +631,7 @@ def contourf( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -674,7 +674,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -713,7 +713,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -752,7 +752,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -795,7 +795,7 @@ def surface( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -834,7 +834,7 @@ def surface( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -873,7 +873,7 @@ def surface( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -940,7 +940,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -980,7 +980,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1020,7 +1020,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1062,7 +1062,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1098,7 +1098,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1134,7 +1134,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1174,7 +1174,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1210,7 +1210,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1246,7 +1246,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index e06a0b7187d..e672281446e 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -228,7 +228,7 @@ def plot( *, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, ax: Axes | None = None, hue: Hashable | None = None, subplot_kws: dict[str, Any] | None = None, @@ -255,8 +255,10 @@ def plot( If passed, make row faceted plots on this dimension name. col : Hashable or None, optional If passed, make column faceted plots on this dimension name. - col_wrap : int or None, optional - Use together with ``col`` to wrap faceted plots. + col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" make the grid as square as possible. ax : matplotlib axes object, optional Axes on which to plot. By default, use the current axes. Mutually exclusive with ``size``, ``figsize`` and facets. @@ -740,8 +742,10 @@ def _plot1d(plotfunc): If passed, make row faceted plots on this dimension name. col : Hashable, optional If passed, make column faceted plots on this dimension name. -col_wrap : int, optional - Use together with ``col`` to wrap faceted plots +col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" make the grid as square as possible. ax : matplotlib axes object, optional If None, uses the current axis. Not applicable when using facets. figsize : Iterable[float] or None, optional @@ -849,7 +853,7 @@ def newplotfunc( linewidth: Hashable | None = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, @@ -1130,7 +1134,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1171,7 +1175,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1212,7 +1216,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1308,8 +1312,10 @@ def _plot2d(plotfunc): If passed, make row faceted plots on this dimension name. col : Hashable or None, optional If passed, make column faceted plots on this dimension name. -col_wrap : int, optional - Use together with ``col`` to wrap faceted plots. +col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" make the grid as square as possible. xincrease : None, True, or False, optional Should the values on the *x* axis be increasing from left to right? If ``None``, use the default for the Matplotlib function. @@ -1420,7 +1426,7 @@ def newplotfunc( ax: Axes | None = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1675,7 +1681,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1715,7 +1721,7 @@ def imshow( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1755,7 +1761,7 @@ def imshow( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1892,7 +1898,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1932,7 +1938,7 @@ def contour( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1972,7 +1978,7 @@ def contour( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2025,7 +2031,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2065,7 +2071,7 @@ def contourf( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2105,7 +2111,7 @@ def contourf( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2158,7 +2164,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2198,7 +2204,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2238,7 +2244,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2342,7 +2348,7 @@ def surface( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2382,7 +2388,7 @@ def surface( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2422,7 +2428,7 @@ def surface( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 857e9170508..f6a9341b08b 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -4,7 +4,7 @@ import inspect import warnings from collections.abc import Callable, Hashable, Iterable -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload from xarray.plot import dataarray_plot from xarray.plot.facetgrid import _easy_facetgrid @@ -65,8 +65,10 @@ def _dsplot(plotfunc): If passed, make row faceted plots on this dimension name. col : Hashable or None, optional If passed, make column faceted plots on this dimension name. -col_wrap : int, optional - Use together with ``col`` to wrap faceted plots. +col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" make the grid as square as possible. ax : matplotlib axes object or None, optional If ``None``, use the current axes. Not applicable when using facets. figsize : Iterable[float] or None, optional @@ -169,7 +171,7 @@ def newplotfunc( hue_style: HueStyleOptions = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, @@ -336,7 +338,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -373,7 +375,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -410,7 +412,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -487,7 +489,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -524,7 +526,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -561,7 +563,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -767,7 +769,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -808,7 +810,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -849,7 +851,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -890,7 +892,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 5da382c1177..08783278b9b 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -114,7 +114,7 @@ class FacetGrid(Generic[T_DataArrayOrSet]): _row_var: Hashable | None _ncol: int _col_var: Hashable | None - _col_wrap: int | None + _col_wrap: int | Literal["auto"] | None row_labels: list[Annotation | None] col_labels: list[Annotation | None] _x_var: None @@ -129,7 +129,7 @@ def __init__( data: T_DataArrayOrSet, col: Hashable | None = None, row: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, figsize: Iterable[float] | None = None, @@ -142,15 +142,16 @@ def __init__( ---------- data : DataArray or Dataset DataArray or Dataset to be plotted. - row, col : str + row, col : hashable or None, optional Dimension names that define subsets of the data, which will be drawn on separate facets in the grid. - col_wrap : int, optional - "Wrap" the grid the for the column variable after this number of columns, + col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - sharex : bool, optional + If "auto" make the grid as square as possible. + sharex : bool, default: True If true, the facets will share *x* axes. - sharey : bool, optional + sharey : bool, default: True If true, the facets will share *y* axes. figsize : Iterable of float or None, optional A tuple (width, height) of the figure in inches. @@ -163,7 +164,6 @@ def __init__( subplot_kws : dict, optional Dictionary of keyword arguments for Matplotlib subplots (:py:func:`matplotlib.pyplot.subplots`). - """ import matplotlib.pyplot as plt @@ -198,13 +198,11 @@ def __init__( # Compute grid shape if single_group: nfacet = len(data[single_group]) - if col: - # idea - could add heuristic for nice shapes like 3x4 - ncol = nfacet - if row: - ncol = 1 - if col_wrap is not None: - # Overrides previous settings + if col_wrap == "auto": + ncol = int(np.ceil(np.sqrt(nfacet))) + elif col_wrap is None: + ncol = nfacet if col else 1 + else: ncol = col_wrap nrow = int(np.ceil(nfacet / ncol)) From a78eeb1a707f5abe49318106ee659a319decff01 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 26 Mar 2026 22:07:18 +0100 Subject: [PATCH 2/7] add tests --- xarray/tests/test_plot.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 5980d449dbb..9cac6a3736f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1668,6 +1668,23 @@ def test_convenient_facetgrid_4d(self) -> None: for ax in g.axs.flat: assert ax.has_data() + @pytest.mark.parametrize( + ["n", "ncols", "nrows"], + [ + pytest.param(1, 1, 1, id="1"), + pytest.param(2, 2, 1, id="2"), + pytest.param(4, 2, 2, id="4"), + pytest.param(6, 3, 2, id="6"), + pytest.param(8, 3, 3, id="8"), + ], + ) + def test_facetgrid_col_wrap_auto(self, n: int, ncols: int, nrows: int) -> None: + a = easy_array((10, 15, n)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap="auto") + + assert_array_equal(g.axs.shape, [nrows, ncols]) + @pytest.mark.filterwarnings("ignore:This figure includes") def test_facetgrid_map_only_appends_mappables(self) -> None: a = easy_array((10, 15, 2, 3)) From ee53ea3b5b76d5cae2b3c0b22dd691d645927c45 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 26 Mar 2026 22:13:17 +0100 Subject: [PATCH 3/7] amend changelog --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 41679520930..b285562eec1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,6 +17,9 @@ New Features - Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`). By `Alfonso Ladino `_. +- Support ``col_wrap='auto'`` in plots that will wrap the grid to be as square + as possible (:pull:`11266`). + By `Michael Niklas `_. Breaking Changes ~~~~~~~~~~~~~~~~ From 622536464cb66a9832d416c2812c6ee20dff4745 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 27 Mar 2026 09:43:24 +0100 Subject: [PATCH 4/7] align grid to figsize --- xarray/plot/dataarray_plot.py | 6 +++--- xarray/plot/dataset_plot.py | 2 +- xarray/plot/facetgrid.py | 11 +++++++++-- xarray/tests/test_plot.py | 21 ++++++++++++--------- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index e672281446e..9fead9934d6 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -258,7 +258,7 @@ def plot( col_wrap : int, None or "auto", optional "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - If "auto" make the grid as square as possible. + If "auto" align the grid to the figsize or keep it as square as possible. ax : matplotlib axes object, optional Axes on which to plot. By default, use the current axes. Mutually exclusive with ``size``, ``figsize`` and facets. @@ -745,7 +745,7 @@ def _plot1d(plotfunc): col_wrap : int, None or "auto", optional "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - If "auto" make the grid as square as possible. + If "auto" align the grid to the figsize or keep it as square as possible. ax : matplotlib axes object, optional If None, uses the current axis. Not applicable when using facets. figsize : Iterable[float] or None, optional @@ -1315,7 +1315,7 @@ def _plot2d(plotfunc): col_wrap : int, None or "auto", optional "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - If "auto" make the grid as square as possible. + If "auto" align the grid to the figsize or keep it as square as possible. xincrease : None, True, or False, optional Should the values on the *x* axis be increasing from left to right? If ``None``, use the default for the Matplotlib function. diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index f6a9341b08b..bc51d1eee80 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -68,7 +68,7 @@ def _dsplot(plotfunc): col_wrap : int, None or "auto", optional "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - If "auto" make the grid as square as possible. + If "auto" align the grid to the figsize or keep it as square as possible. ax : matplotlib axes object or None, optional If ``None``, use the current axes. Not applicable when using facets. figsize : Iterable[float] or None, optional diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 08783278b9b..030047ae5d9 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -148,7 +148,7 @@ def __init__( col_wrap : int, None or "auto", optional "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - If "auto" make the grid as square as possible. + If "auto" align the grid to the figsize or keep it as square as possible. sharex : bool, default: True If true, the facets will share *x* axes. sharey : bool, default: True @@ -195,11 +195,18 @@ def __init__( else: raise ValueError("Pass a coordinate name as an argument for row or col") + # exhaust generators + figsize = None if figsize is None else tuple(s for s in figsize) + # Compute grid shape if single_group: nfacet = len(data[single_group]) if col_wrap == "auto": - ncol = int(np.ceil(np.sqrt(nfacet))) + # try to align the grid to the figsize. If figsize is unknown it gets + # computed from the grid, so lets make it square in this case. + faspect = 1 if figsize is None else (figsize[0] / figsize[1]) + # only wrap if > 3 images + ncol = int(np.ceil(np.sqrt(nfacet * faspect))) if nfacet > 3 else nfacet elif col_wrap is None: ncol = nfacet if col else 1 else: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 9cac6a3736f..4a874921926 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1669,21 +1669,24 @@ def test_convenient_facetgrid_4d(self) -> None: assert ax.has_data() @pytest.mark.parametrize( - ["n", "ncols", "nrows"], + ["n", "figsize", "expected_shape"], [ - pytest.param(1, 1, 1, id="1"), - pytest.param(2, 2, 1, id="2"), - pytest.param(4, 2, 2, id="4"), - pytest.param(6, 3, 2, id="6"), - pytest.param(8, 3, 3, id="8"), + pytest.param(1, None, [1, 1], id="1"), + pytest.param(3, None, [1, 3], id="3"), # <4 should not be wrapped + pytest.param(6, None, [2, 3], id="6"), + pytest.param(8, None, [3, 3], id="8"), + pytest.param(8, [10, 5], [2, 4], id="8-aspect=2"), + pytest.param(8, [5, 10], [4, 2], id="8-aspect=0.5"), ], ) - def test_facetgrid_col_wrap_auto(self, n: int, ncols: int, nrows: int) -> None: + def test_facetgrid_col_wrap_auto( + self, n: int, figsize: None | tuple[int, int], expected_shape: tuple[int, int] + ) -> None: a = easy_array((10, 15, n)) d = DataArray(a, dims=["y", "x", "z"]) - g = self.plotfunc(d, x="x", y="y", col="z", col_wrap="auto") + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap="auto", figsize=figsize) - assert_array_equal(g.axs.shape, [nrows, ncols]) + assert_array_equal(g.axs.shape, expected_shape) @pytest.mark.filterwarnings("ignore:This figure includes") def test_facetgrid_map_only_appends_mappables(self) -> None: From 770ff55e8e6b4159de0034896ac90b200bd663bd Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 27 Mar 2026 09:47:13 +0100 Subject: [PATCH 5/7] simplify tuple gen --- xarray/plot/facetgrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 030047ae5d9..108fe0cc4bd 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -196,7 +196,7 @@ def __init__( raise ValueError("Pass a coordinate name as an argument for row or col") # exhaust generators - figsize = None if figsize is None else tuple(s for s in figsize) + figsize = None if figsize is None else tuple(figsize) # Compute grid shape if single_group: From 1b3b8c1f9e490356f41136d56b21f8d2642e10e7 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 27 Mar 2026 15:51:27 +0100 Subject: [PATCH 6/7] support extreme figsizes --- xarray/plot/facetgrid.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 108fe0cc4bd..5fb5f810121 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -203,10 +203,14 @@ def __init__( nfacet = len(data[single_group]) if col_wrap == "auto": # try to align the grid to the figsize. If figsize is unknown it gets - # computed from the grid, so lets make it square in this case. - faspect = 1 if figsize is None else (figsize[0] / figsize[1]) + # computed from the grid, so lets keep it as square as possible + faspect = 1 if figsize is None else figsize[0] / figsize[1] # only wrap if > 3 images - ncol = int(np.ceil(np.sqrt(nfacet * faspect))) if nfacet > 3 else nfacet + ncol = ( + min(nfacet, int(np.ceil(np.sqrt(nfacet * faspect)))) + if nfacet > 3 + else nfacet + ) elif col_wrap is None: ncol = nfacet if col else 1 else: From 31d2f51a4c8dc39e53560378f497d626cff7065c Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 27 Mar 2026 16:43:20 +0100 Subject: [PATCH 7/7] now even support aspect --- xarray/plot/facetgrid.py | 38 ++++++++++++++++++++++++++++---------- xarray/tests/test_plot.py | 26 +++++++++++++++++--------- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 5fb5f810121..2bdd11a9f59 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -56,6 +56,31 @@ def _nicetitle(coord, value, maxchar, template): return title +def _auto_grid( + nfacet: int, figsize: tuple[float, ...] | None, aspect: float +) -> tuple[int, int]: + + # Try to align the grid to the figsize. If figsize is unknown it gets + # computed from the grid, so lets keep it as square as possible + faspect = 1 if figsize is None else figsize[0] / figsize[1] + + # Only wrap if > 3 images + if nfacet <= 3: + return nfacet, 1 + + # Geometric ideal case + ncol = int(np.ceil(np.sqrt(nfacet * faspect / aspect))) + ncol = max(1, min(ncol, nfacet)) + nrow = int(np.ceil(nfacet / ncol)) + + # Reduce columns as long as we don't need more rows + # This eliminates empty slots in the last row if aspect < 1 + while ncol > 1 and (ncol - 1) * nrow >= nfacet: + ncol -= 1 + + return ncol, nrow + + T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") @@ -202,20 +227,13 @@ def __init__( if single_group: nfacet = len(data[single_group]) if col_wrap == "auto": - # try to align the grid to the figsize. If figsize is unknown it gets - # computed from the grid, so lets keep it as square as possible - faspect = 1 if figsize is None else figsize[0] / figsize[1] - # only wrap if > 3 images - ncol = ( - min(nfacet, int(np.ceil(np.sqrt(nfacet * faspect)))) - if nfacet > 3 - else nfacet - ) + ncol, nrow = _auto_grid(nfacet, figsize, aspect) elif col_wrap is None: ncol = nfacet if col else 1 + nrow = int(np.ceil(nfacet / ncol)) else: ncol = col_wrap - nrow = int(np.ceil(nfacet / ncol)) + nrow = int(np.ceil(nfacet / ncol)) # Set the subplot kwargs subplot_kws = {} if subplot_kws is None else subplot_kws diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 4a874921926..cf8d0e2999e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1669,22 +1669,30 @@ def test_convenient_facetgrid_4d(self) -> None: assert ax.has_data() @pytest.mark.parametrize( - ["n", "figsize", "expected_shape"], + ["n", "figsize", "aspect", "expected_shape"], [ - pytest.param(1, None, [1, 1], id="1"), - pytest.param(3, None, [1, 3], id="3"), # <4 should not be wrapped - pytest.param(6, None, [2, 3], id="6"), - pytest.param(8, None, [3, 3], id="8"), - pytest.param(8, [10, 5], [2, 4], id="8-aspect=2"), - pytest.param(8, [5, 10], [4, 2], id="8-aspect=0.5"), + pytest.param(1, None, 1, [1, 1], id="1"), + pytest.param(3, None, 1, [1, 3], id="3"), # <4 should not be wrapped + pytest.param(6, None, 1, [2, 3], id="6"), + pytest.param(8, None, 1, [3, 3], id="8"), + pytest.param(8, [10, 5], 1, [2, 4], id="8-figaspect=2"), + pytest.param(8, [5, 10], 1, [4, 2], id="8-figaspect=0.5"), + pytest.param(8, None, 4, [4, 2], id="8-aspect=4"), + pytest.param(8, None, 0.25, [2, 4], id="8-aspect=0.25"), ], ) def test_facetgrid_col_wrap_auto( - self, n: int, figsize: None | tuple[int, int], expected_shape: tuple[int, int] + self, + n: int, + figsize: None | tuple[int, int], + aspect: int, + expected_shape: tuple[int, int], ) -> None: a = easy_array((10, 15, n)) d = DataArray(a, dims=["y", "x", "z"]) - g = self.plotfunc(d, x="x", y="y", col="z", col_wrap="auto", figsize=figsize) + g = self.plotfunc( + d, x="x", y="y", col="z", col_wrap="auto", figsize=figsize, aspect=aspect + ) assert_array_equal(g.axs.shape, expected_shape)