@@ -515,20 +515,49 @@ def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None:
515515 break
516516
def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
    """Merge the shapes belonging to this strategy's blocks into a single entry.

    Block ids whose configured size is the literal integer ``1`` are ignored
    when deciding which shapes to merge.  The remaining ids form the expected
    sequence: when ``shapes`` contains consecutive single-block entries that
    match that sequence exactly, they are folded together with
    ``CompactedShape.combine``.  Anything else is passed through unchanged,
    in its original order.

    Args:
        shapes: Shapes to scan, in order.

    Returns:
        A new list with every complete run combined into one shape.  The
        input list is never returned directly, so callers may mutate the
        result without affecting ``shapes``.
    """
    block_sizes = CompileEnvironment.current().block_sizes

    def is_unit_sized(block_id) -> bool:
        # Only a concrete int equal to 1 counts; any non-int size is kept.
        size = block_sizes[block_id].size
        return isinstance(size, int) and size == 1

    # Filter out unit-sized blocks that don't need compacting.
    compact_block_ids = [
        block_id for block_id in self.block_ids if not is_unit_sized(block_id)
    ]
    if not compact_block_ids:
        # Nothing to merge; copy so the return value is always a fresh list,
        # consistent with the main path below.
        return list(shapes)

    first_id, *remaining_ids = compact_block_ids
    output = []
    shape_queue = collections.deque(shapes)
    while shape_queue:
        shape = shape_queue.popleft()
        # Only a single-block shape for the first id can start a run.
        if len(shape.block_ids) != 1 or shape.block_ids[0] != first_id:
            output.append(shape)
            continue

        # Try to collect the full sequence, peeking before consuming so a
        # non-matching shape stays queued for the next outer iteration.
        group_shapes = [shape]
        for expected in remaining_ids:
            if (
                not shape_queue
                or len(shape_queue[0].block_ids) != 1
                or shape_queue[0].block_ids[0] != expected
            ):
                # Partial match - emit what was collected, uncombined.
                output.extend(group_shapes)
                break
            group_shapes.append(shape_queue.popleft())
        else:
            # Full match - combine into one.
            combined = group_shapes[0]
            for other in group_shapes[1:]:
                combined = combined.combine(other)
            output.append(combined)
    return output
533562
534563
0 commit comments