Skip to content

Commit ade3cd4

Browse files
committed
fix
1 parent a71d7c1 commit ade3cd4

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

helion/_compiler/tile_strategy.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -515,20 +515,49 @@ def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None:
515515
break
516516

517517
def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
518+
env = CompileEnvironment.current()
519+
# Filter out unit-sized blocks that don't need compacting
520+
compact_block_ids = [
521+
block_id
522+
for block_id in self.block_ids
523+
if not (
524+
isinstance(env.block_sizes[block_id].size, int)
525+
and env.block_sizes[block_id].size == 1
526+
)
527+
]
528+
if not compact_block_ids:
529+
return shapes
530+
518531
output = []
519532
shape_queue = collections.deque(shapes)
520533
while shape_queue:
521534
shape = shape_queue.popleft()
522-
if len(shape.block_ids) != 1 or shape.block_ids[0] not in self.block_ids:
535+
# Check if this starts our flattened sequence
536+
if len(shape.block_ids) != 1 or shape.block_ids[0] != compact_block_ids[0]:
523537
output.append(shape)
524538
continue
525-
assert shape.block_ids[0] == self.block_ids[0]
526-
for expected in self.block_ids[1:]:
527-
new_shape = shape_queue.popleft()
528-
assert len(new_shape.block_ids) == 1
529-
assert new_shape.block_ids[0] == expected
530-
shape = shape.combine(new_shape)
531-
output.append(shape)
539+
540+
# Try to collect the full sequence
541+
group_shapes = [shape]
542+
found_complete_sequence = True
543+
for expected in compact_block_ids[1:]:
544+
if (
545+
shape_queue
546+
and len(shape_queue[0].block_ids) == 1
547+
and shape_queue[0].block_ids[0] == expected
548+
):
549+
group_shapes.append(shape_queue.popleft())
550+
else:
551+
# Partial match - don't combine
552+
found_complete_sequence = False
553+
output.extend(group_shapes)
554+
break
555+
556+
if found_complete_sequence:
557+
# Full match - combine into one
558+
for s in group_shapes[1:]:
559+
shape = shape.combine(s)
560+
output.append(shape)
532561
return output
533562

534563

0 commit comments

Comments
 (0)