From 99f77a32fb440ff36b980dc1804126885e98e936 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 27 Mar 2026 12:11:58 +0000 Subject: [PATCH] Add RGB support to plot_array and unit test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `plot_array` now detects 3- or 4-channel arrays and skips colormap, norm, and colorbar — all of which are inappropriate for RGB images. A `test__plot_array_rgb` unit test covers the new code path. Co-Authored-By: Claude Sonnet 4.6 --- autoarray/plot/array.py | 31 +++++++++++++------ .../plot/test_structure_plotters.py | 21 +++++++++++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/autoarray/plot/array.py b/autoarray/plot/array.py index c5b86e41..74f128ee 100644 --- a/autoarray/plot/array.py +++ b/autoarray/plot/array.py @@ -133,6 +133,8 @@ def plot_array( if array is None or array.size == 0: return + is_rgb = array.ndim == 3 and array.shape[2] in (3, 4) + if colormap is None: from autoarray.plot.utils import _default_colormap colormap = _default_colormap() @@ -194,21 +196,30 @@ def plot_array( h, w = array.shape[:2] _box_aspect = (w / h) if h > 0 else 1.0 - im = ax.imshow( - array, - cmap=colormap, - norm=norm, - extent=extent, - aspect="auto", # image fills the axes box; box shape set below - origin=origin_imshow, - ) + if is_rgb: + im = ax.imshow( + array, + extent=extent, + aspect="auto", + origin=origin_imshow, + ) + else: + im = ax.imshow( + array, + cmap=colormap, + norm=norm, + extent=extent, + aspect="auto", # image fills the axes box; box shape set below + origin=origin_imshow, + ) # Shape the axes box to match the data so there is no surrounding # whitespace when the panel is embedded in a subplot grid. ax.set_aspect(_box_aspect, adjustable="box") - from autoarray.plot.utils import _apply_colorbar - _apply_colorbar(im, ax, cb_unit=cb_unit, is_subplot=not owns_figure) + if not is_rgb: + from autoarray.plot.utils import _apply_colorbar + _apply_colorbar(im, ax, cb_unit=cb_unit, is_subplot=not owns_figure) # --- overlays -------------------------------------------------------------- if array_overlay is not None: diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index 0f66bab9..fc139a18 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -145,3 +145,24 @@ def test__array_rgb( ) assert path.join(plot_path, "array_rgb.png") in plot_patch.paths + + +def test__plot_array_rgb( + array_2d_rgb_7x7, + plot_path, + plot_patch, +): + """ + `plot_array` (the high-level function) must handle `Array2DRGB` inputs without + applying a colormap, norm, or colorbar — all of which would raise errors or + produce nonsense for a 3-channel image. + """ + aplt.plot_array( + array=array_2d_rgb_7x7, + title="RGB Test", + output_path=plot_path, + output_filename="array_rgb_high_level", + output_format="png", + ) + + assert path.join(plot_path, "array_rgb_high_level.png") in plot_patch.paths