From e0c5e9ca929b63f5f98f88c8ba7b380b364edfac Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Fri, 27 Feb 2026 20:48:53 -0500
Subject: [PATCH 1/2] remove ReduceContext.range_to_ends [pr]

make merge_reduce_ends pure. this state is causing issue when introducing
more reduce merging rewrites
---
 tinygrad/codegen/late/devectorizer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py
index 7f866cd9c8f25..6ddccedfced7a 100644
--- a/tinygrad/codegen/late/devectorizer.py
+++ b/tinygrad/codegen/late/devectorizer.py
@@ -1,7 +1,7 @@
 from typing import Any, cast
 import functools, itertools
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
 from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
 from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
@@ -308,8 +308,6 @@ def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
 @dataclass
 class ReduceContext:
   acc_num: int = 0
-  # track ENDs by range for merging parallel reduces
-  range_to_ends: dict[tuple[UOp, ...], list[UOp]] = field(default_factory=dict)
 
 def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
   # if this has a horizontal reduction component, do that first
@@ -336,12 +334,14 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
   ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
   if len(reduce_range) == 0: return ret
   end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)
-  ctx.range_to_ends.setdefault(reduce_range, []).append(end)
   return acc.after(end).index(UOp.const(dtypes.int, 0))
 
 def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
   # merge ENDs that share the same range
-  subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in ctx.range_to_ends.items() if len(ends) > 1 for e in ends}
+  range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
+  for u in sink.backward_slice:
+    if u.op is Ops.END: range_to_ends.setdefault(u.src[1:], []).append(u)
+  subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends}
   return sink.substitute(subs) if subs else None
 
 pm_reduce = PatternMatcher([

From 55d4c90974eb647c1255bcb1434c9e090b6657ab Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Fri, 27 Feb 2026 21:57:18 -0500
Subject: [PATCH 2/2] tag

---
 tinygrad/codegen/late/devectorizer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py
index 6ddccedfced7a..230d5e357c372 100644
--- a/tinygrad/codegen/late/devectorizer.py
+++ b/tinygrad/codegen/late/devectorizer.py
@@ -333,14 +333,14 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
   ctx.acc_num += 1
   ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
   if len(reduce_range) == 0: return ret
-  end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)
+  end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range).rtag("mergeable")
   return acc.after(end).index(UOp.const(dtypes.int, 0))
 
 def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
-  # merge ENDs that share the same range
+  # merge ENDs that share the same range (only those created by reduce_to_acc)
   range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
   for u in sink.backward_slice:
-    if u.op is Ops.END: range_to_ends.setdefault(u.src[1:], []).append(u)
+    if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u)
   subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends}
   return sink.substitute(subs) if subs else None
 