diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 7f866cd9c8f25..230d5e357c372 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -1,7 +1,7 @@ from typing import Any, cast import functools, itertools from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate @@ -308,8 +308,6 @@ def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp): @dataclass class ReduceContext: acc_num: int = 0 - # track ENDs by range for merging parallel reduces - range_to_ends: dict[tuple[UOp, ...], list[UOp]] = field(default_factory=dict) def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]: # if this has a horizontal reduction component, do that first @@ -335,13 +333,15 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) if len(reduce_range) == 0: return ret - end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range) - ctx.range_to_ends.setdefault(reduce_range, []).append(end) + end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range).rtag("mergeable") return acc.after(end).index(UOp.const(dtypes.int, 0)) def merge_reduce_ends(ctx:ReduceContext, sink:UOp): - # merge ENDs that share the same range - subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in ctx.range_to_ends.items() if len(ends) > 1 for e in ends} + # merge ENDs that share the same range (only those created by reduce_to_acc) + range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {} + for u in sink.backward_slice: + if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u) + subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends} return sink.substitute(subs) if subs else None pm_reduce = PatternMatcher([