From cc806ab74260cf3bdfca3913d0b262cc22c484c2 Mon Sep 17 00:00:00 2001 From: FumeDev Date: Tue, 30 Apr 2024 04:07:13 +0000 Subject: [PATCH] Added new tests for facetgrid in test_axisgrid.py. --- tests/test_axisgrid.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_axisgrid.py b/tests/test_axisgrid.py index 6470edfa4f..b092d1fc08 100644 --- a/tests/test_axisgrid.py +++ b/tests/test_axisgrid.py @@ -1878,3 +1878,53 @@ def test_ax_warning(self, long_df): with pytest.warns(UserWarning): g = ag.jointplot(data=long_df, x="x", y="y", ax=ax) assert g.ax_joint.collections +def test_facetgrid_col_wrap_with_sharey(): + import pandas as pd + import numpy as np + import seaborn as sns + import matplotlib.pyplot as plt + + # Generate data + data = pd.DataFrame({ + "x": np.random.randn(100), + "y": np.random.rand(100), + "cat": np.tile(np.array(["A", "B", "C", "D"]), 25) + }) + + # Create FacetGrid with col_wrap and sharey + g = sns.FacetGrid(data, col="cat", col_wrap=2, sharey='row', height=3) + g.map(plt.scatter, "x", "y") + + # Collect yAxis limits to verify correct sharing + first_col_ys = [ax.get_shared_y_axes().get_siblings(ax)[0].get_ylim() for ax in g.axes[:2]] + second_col_ys = [ax.get_shared_y_axes().get_siblings(ax)[0].get_ylim() for ax in g.axes[2:]] + + # Assertions to check if sharey='row' is working as expected + assert first_col_ys[0] == first_col_ys[1], "First column should share y-axis" + assert second_col_ys[0] == second_col_ys[1], "Second column should share y-axis" + assert first_col_ys[0] != second_col_ys[0], "Different columns should not share y-axis between them" +def test_facetgrid_col_wrap_with_sharex(): + import pandas as pd + import numpy as np + import seaborn as sns + import matplotlib.pyplot as plt + + # Generate data + data = pd.DataFrame({ + "x": np.random.randn(100), + "y": np.random.rand(100), + "cat": np.tile(np.array(["A", "B", "C", "D"]), 25) + }) + + # Create FacetGrid with col_wrap and sharex + g = sns.FacetGrid(data, col="cat", col_wrap=2, sharex='col', height=3) + g.map(plt.scatter, "x", "y") + + # Collect xAxis limits to verify correct sharing + top_row_xs = [ax.get_shared_x_axes().get_siblings(ax)[0].get_xlim() for ax in g.axes[::2]] + bottom_row_xs = [ax.get_shared_x_axes().get_siblings(ax)[0].get_xlim() for ax in g.axes[1::2]] + + # Assertions to check if sharex='col' is working as expected + assert top_row_xs[0] == top_row_xs[1], "Top row should share x-axis" + assert bottom_row_xs[0] == bottom_row_xs[1], "Bottom row should share x-axis" + assert top_row_xs[0] != bottom_row_xs[0], "Top and bottom rows should not share x-axis"