From bacb06c2c9dabf28fb28e4d31cb7c6008c940c75 Mon Sep 17 00:00:00 2001 From: Niru Maheswaranathan Date: Sat, 17 May 2025 11:11:19 -0700 Subject: [PATCH] Fix corr mode typo and add tests --- jetplot/images.py | 7 ++++--- tests/test_images.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 tests/test_images.py diff --git a/jetplot/images.py b/jetplot/images.py index 01a0bb2..90d06f7 100644 --- a/jetplot/images.py +++ b/jetplot/images.py @@ -28,8 +28,9 @@ def img( Args: img: array_like, The array to visualize. - mode: string, Either 'div' for a diverging image or 'seq' for - sequential (default: 'div'). + mode: string, One of 'div' for a diverging image, 'seq' for + sequential, 'cov' for covariance matrices, or 'corr' for + correlation matrices (default: 'div'). cmap: string, Colormap to use. aspect: string, Either 'equal' or 'auto' """ @@ -57,7 +58,7 @@ def img( cmap = "viridis" elif mode == "cov": vmin, vmax, cmap, cbar = 0, 1, "viridis", True - elif mode == "cov": + elif mode == "corr": vmin, vmax, cmap, cbar = -1, 1, "seismic", True else: raise ValueError("Unrecognized mode: '" + mode + "'") diff --git a/tests/test_images.py b/tests/test_images.py new file mode 100644 index 0000000..e810a73 --- /dev/null +++ b/tests/test_images.py @@ -0,0 +1,17 @@ +import numpy as np +from matplotlib import pyplot as plt +from jetplot import images + + +def test_img_corr_mode(): + data = np.eye(3) + fig, ax = plt.subplots() + im = images.img(data, mode="corr", fig=fig, ax=ax) + + # Check defaults for correlation mode + assert im.get_cmap().name == "seismic" + assert im.get_clim() == (-1, 1) + + # Colorbar should have been added + assert len(fig.axes) == 2 + plt.close(fig)