Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tinygrad/codegen/late/devectorizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, cast
import functools, itertools
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
Expand Down Expand Up @@ -308,8 +308,6 @@ def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
@dataclass
class ReduceContext:
  """Mutable state handed as `ctx` to the reduce rewrite functions in this file."""
  # counter incremented by reduce_to_acc each time it runs (`ctx.acc_num += 1`)
  acc_num: int = 0
  # track ENDs by range for merging parallel reduces; reduce_to_acc appends the END it
  # builds under the tuple of ranges that END closes
  range_to_ends: dict[tuple[UOp, ...], list[UOp]] = field(default_factory=dict)

def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
# if this has a horizontal reduction component, do that first
Expand All @@ -335,13 +333,15 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
if len(reduce_range) == 0: return ret
end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)
ctx.range_to_ends.setdefault(reduce_range, []).append(end)
end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range).rtag("mergeable")
return acc.after(end).index(UOp.const(dtypes.int, 0))

def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
  """Merge the "mergeable"-tagged ENDs reachable from `sink` that share the same trailing sources.

  Only END uops whose tag is "mergeable" are considered (reduce_to_acc tags the ENDs it creates
  this way). They are bucketed by `u.src[1:]` -- presumably the ranges the END closes, since
  the merged END is rebuilt with `.end(*ranges)`. Every bucket holding more than one END is
  replaced by a single END over a group of the original ENDs' first sources.

  `ctx` is accepted for signature compatibility with the other rewrite functions; it is not read.

  Returns the substituted sink, or None when no bucket had more than one END.
  """
  # merge ENDs that share the same range (only those created by reduce_to_acc)
  range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
  for u in sink.backward_slice:
    if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u)
  subs: dict[UOp, UOp] = {}
  for ranges, ends in range_to_ends.items():
    if len(ends) <= 1: continue  # a lone END has nothing to merge with
    # build the merged END once per bucket and point every original END at it
    merged = UOp.group(*(e.src[0] for e in ends)).end(*ranges)
    for e in ends: subs[e] = merged
  return sink.substitute(subs) if subs else None

pm_reduce = PatternMatcher([
Expand Down
Loading