@@ -914,66 +914,84 @@ def _get_border_axes(self) -> dict[str, list[paxes.Axes]]:
914914 """
915915
916916 gs = self .gridspec
917- all_axes = self .axes
917+
918+ # Skip colorbars or panels etc
919+ all_axes = [axi for axi in self .axes if axi .number is not None ]
918920
919921 # Handle empty cases
920922 nrows , ncols = gs .nrows , gs .ncols
923+ border_axes = dict (top = [], bottom = [], left = [], right = [])
921924 if nrows == 0 or ncols == 0 or not all_axes :
922- return dict (top = [], bottom = [], left = [], right = [])
925+ return border_axes
926+ # We cannot use the gridspec on the axes as it
927+ # is modified when a colorbar is added. Use self.gridspec
928+ # as a reference.
929+ # Reconstruct the grid based on axis locations. Note that
930+ # spanning axes will fit into one of the boxes. Check
931+ # this with unittest to see how empty axes are handles
932+ grid = np .zeros ((gs .nrows , gs .ncols ))
933+ for axi in all_axes :
934+ # Infer coordinate from grdispec
935+ spec = axi .get_subplotspec ()
936+ spans = spec ._get_rows_columns ()
937+ rowspans = spans [:2 ]
938+ colspans = spans [- 2 :]
939+
940+ grid [
941+ rowspans [0 ] : rowspans [1 ] + 1 ,
942+ colspans [0 ] : colspans [1 ] + 1 ,
943+ ] = axi .number
944+ directions = {
945+ "left" : (0 , - 1 ),
946+ "right" : (0 , 1 ),
947+ "top" : (- 1 , 0 ),
948+ "bottom" : (1 , 0 ),
949+ }
923950
924- # Find occupied grid cells and valid axes
925- occupied_cells = set ()
926- axes_with_spec = []
951+ def is_border (pos , grid , target , direction ):
952+ x , y = pos
953+ # Check if we are at an edge of the grid (out-of-bounds).
954+ if x < 0 :
955+ return True
956+ elif x > grid .shape [0 ] - 1 :
957+ return True
958+
959+ if y < 0 :
960+ return True
961+ elif y > grid .shape [1 ] - 1 :
962+ return True
963+
964+ # Check if we reached a plot or an internal edge
965+ if grid [x , y ] != target and grid [x , y ] > 0 :
966+ return False
967+ if grid [x , y ] == 0 :
968+ return True
969+ dx , dy = direction
970+ new_pos = (x + dx , y + dy )
971+ return is_border (new_pos , grid , target , direction )
972+
973+ from itertools import product
927974
928975 for axi in all_axes :
929976 spec = axi .get_subplotspec ()
930- if spec is not None :
931- axes_with_spec .append ((axi , spec ))
932- r0 , r1 = spec .rowspan .start , spec .rowspan .stop
933- c0 , c1 = spec .colspan .start , spec .colspan .stop
934- for r in range (r0 , r1 ):
935- for c in range (c0 , c1 ):
936- occupied_cells .add ((r , c ))
937-
938- if not axes_with_spec :
939- return dict (top = [], bottom = [], left = [], right = [])
940-
941- # Initialize border axes sets
942- border_axes_sets = dict (top = set (), bottom = set (), left = set (), right = set ())
943-
944- # Check each axis against border criteria
945- for axi , spec in axes_with_spec :
946- r0 , r1 = spec .rowspan .start , spec .rowspan .stop
947- c0 , c1 = spec .colspan .start , spec .colspan .stop
948-
949- # Check top border
950- if r0 == 0 or (
951- r0 == 1 and any ((0 , c ) not in occupied_cells for c in range (c0 , c1 ))
952- ):
953- border_axes_sets ["top" ].add (axi )
954-
955- # Check bottom border
956- if r1 == nrows or (
957- r1 == nrows - 1
958- and any ((nrows - 1 , c ) not in occupied_cells for c in range (c0 , c1 ))
959- ):
960- border_axes_sets ["bottom" ].add (axi )
961-
962- # Check left border
963- if c0 == 0 or (
964- c0 == 1 and any ((r , 0 ) not in occupied_cells for r in range (r0 , r1 ))
965- ):
966- border_axes_sets ["left" ].add (axi )
967-
968- # Check right border
969- if c1 == ncols or (
970- c1 == ncols - 1
971- and any ((r , ncols - 1 ) not in occupied_cells for r in range (r0 , r1 ))
972- ):
973- border_axes_sets ["right" ].add (axi )
974-
975- # Convert sets to lists
976- return {key : list (val ) for key , val in border_axes_sets .items ()}
977+ spans = spec ._get_rows_columns ()
978+ rowspan = spans [:2 ]
979+ colspan = spans [- 2 :]
980+ # Check all cardinal directions. When we find a
981+ # border for any starting conditions we break and
982+ # consider it a border. This could mean that for some
983+ # partial overlaps we consider borders that should
984+ # not be borders -- we are conservative in this
985+ # regard
986+ for direction , d in directions .items ():
987+ xs = range (rowspan [0 ], rowspan [1 ] + 1 )
988+ ys = range (colspan [0 ], colspan [1 ] + 1 )
989+ for x , y in product (xs , ys ):
990+ pos = (x , y )
991+ if is_border (pos = pos , grid = grid , target = axi .number , direction = d ):
992+ border_axes [direction ].append (axi )
993+ break
994+ return border_axes
977995
978996 def _get_align_coord (self , side , axs , includepanels = False ):
979997 """
0 commit comments