From 0831ab65faf5e4a9025284241ec9d93fb4aaa0a3 Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Sat, 28 Feb 2026 08:14:35 -0500
Subject: [PATCH] unify_parallel_reduce_ranges

---
 test/backend/test_arange.py   | 34 ++++++++++++++++++++++++++++++++++
 test/backend/test_schedule.py | 35 +++++++++++++++++++++++++++++++++++
 tinygrad/codegen/simplify.py  | 18 ++++++++++++++++++
 3 files changed, 87 insertions(+)

diff --git a/test/backend/test_arange.py b/test/backend/test_arange.py
index a6b62bfe248a3..64ea5a11e1cb6 100644
--- a/test/backend/test_arange.py
+++ b/test/backend/test_arange.py
@@ -43,6 +43,40 @@ def test_tri_complexity(self):
 DSET, DDIM = 2048, 32
 
 class TestIndexing(unittest.TestCase):
+  def test_arange_getitem(self):
+    dataset = Tensor.rand(64, 32).realize()
+    with Context(NOOPT=1):
+      GlobalCounters.reset()
+      out = dataset[Tensor.arange(64)]
+      sched = out.schedule()
+      self.assertEqual(len(sched), 1)
+      run_schedule(sched)
+      assert GlobalCounters.global_ops < 64*64, f"too many ops {GlobalCounters.global_ops}"
+    np.testing.assert_allclose(out.numpy(), dataset.numpy())
+
+  def test_two_arange_getitems(self):
+    a = Tensor.rand(64, 32).realize()
+    b = Tensor.rand(64, 32).realize()
+    with Context(NOOPT=1):
+      GlobalCounters.reset()
+      out = a[Tensor.arange(64)] + b[Tensor.arange(64)]
+      sched = out.schedule()
+      self.assertEqual(len(sched), 1)
+      run_schedule(sched)
+      assert GlobalCounters.global_ops < 64*64, f"too many ops {GlobalCounters.global_ops}"
+    np.testing.assert_allclose(out.numpy(), a.numpy() + b.numpy())
+
+  def test_arange_arange_getitem(self):
+    a = Tensor.rand(64, 64).realize()
+    with Context(NOOPT=1):
+      GlobalCounters.reset()
+      out = a[Tensor.arange(64), Tensor.arange(64)]
+      sched = out.schedule()
+      self.assertEqual(len(sched), 1)
+      run_schedule(sched)
+      assert GlobalCounters.global_ops < 64*64, f"too many ops {GlobalCounters.global_ops}"
+    np.testing.assert_allclose(out.numpy(), np.diag(a.numpy()))
+
   def test_arange_2_reduce(self):
     needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
     needle[1337] = 1
diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py
index 260a2b33d233e..204ba9f5c7155 100644
--- a/test/backend/test_schedule.py
+++ b/test/backend/test_schedule.py
@@ -1094,6 +1094,41 @@ def test_self_assign_no_empty_kernel(self):
     run_schedule(check_schedule(a, 0, filter_sink=False))
     self.assertListEqual(a.tolist(), [[1.]*shape[1]]*shape[0])
 
+  def test_parallel_reduces_share_ops(self):
+    """Three independent reduces over the same dimension should have similar ops to one reduce."""
+    X = Tensor.rand(256, 1000).realize()
+    Y = Tensor.randint(256, low=0, high=10).realize()
+    GlobalCounters.reset()
+    with Context(SPLIT_REDUCEOP=0):
+      X.sparse_categorical_crossentropy(Y, label_smoothing=0.1).realize()
+    ops_builtin = GlobalCounters.global_ops
+    def scc2(self, Y, ignore_index=-1, label_smoothing=0.0):
+      log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
+      arange = Tensor.arange(log_probs.shape[0], device=self.device)
+      y = log_probs[arange, Y.flatten()] * loss_mask.unsqueeze(0)
+      smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask.unsqueeze(0)).sum()
+      return -((1 - label_smoothing) * y.sum() + smoothing) / loss_mask.sum()
+    GlobalCounters.reset()
+    with Context(SPLIT_REDUCEOP=0):
+      scc2(X, Y, label_smoothing=0.1).realize()
+    ops_scc2 = GlobalCounters.global_ops
+    self.assertLess(ops_scc2 / ops_builtin, 2.0, f"ops ratio too high: {ops_scc2}/{ops_builtin} = {ops_scc2/ops_builtin:.1f}x")
+
+  def test_two_parallel_sums(self):
+    a = Tensor.rand(64, 128).realize()
+    with Context(SPLIT_REDUCEOP=0):
+      out = a.sum(axis=1) + a.mean(axis=1)
+      np.testing.assert_allclose(out.numpy(), a.numpy().sum(axis=1) + a.numpy().mean(axis=1), atol=1e-4)
+
+  def test_sequential_reduces_stay_separate(self):
+    a = Tensor.rand(32, 50).realize()
+    a_np = a.numpy()
+    with Context(SPLIT_REDUCEOP=0):
+      out = a.log_softmax(axis=-1)
+      a_shifted = a_np - a_np.max(axis=-1, keepdims=True)
+      expected = a_shifted - np.log(np.exp(a_shifted).sum(axis=-1, keepdims=True))
+      np.testing.assert_allclose(out.numpy(), expected, atol=1e-5)
+
 class TestLimitBufs(unittest.TestCase):
   @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
   def test_limit_bufs_with_var(self):
diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py
index 1f88c4a09d5a7..66125a0fbf84c 100644
--- a/tinygrad/codegen/simplify.py
+++ b/tinygrad/codegen/simplify.py
@@ -36,8 +36,26 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
           u = nidx
   return u
 
+def unify_parallel_reduce_ranges(sink:UOp) -> UOp|None:
+  """Unify ranges of independent parallel reduces with same sizes so GROUPTOP applies consistently. Returns None if nothing is unified."""
+  # group reduces by reduce range ends (the end UOp, not its .arg, so distinct symbolic sizes never share a key) + output ranges (must match exactly)
+  groups: dict[tuple, list[UOp]] = {}
+  for u in sink.toposort():
+    if u.op is Ops.REDUCE:
+      key = (tuple(r.src[0] for r in u.src[1:]), tuple(sorted(id(r) for r in set(u.src[0].ranges) - set(u.src[1:]))))
+      groups.setdefault(key, []).append(u)
+  range_subs: dict[UOp, UOp] = {}
+  for reds in groups.values():
+    # TODO: pick largest independent subset instead of skipping whole group
+    if any(r2 in r1.backward_slice for r1 in reds for r2 in reds if r1 is not r2): continue
+    for red in reds[1:]:
+      for r, c in zip(red.src[1:], reds[0].src[1:]):
+        if r is not c: range_subs[r] = c
+  return sink.substitute(range_subs) if range_subs else None
+
 pm_simplify_ranges = PatternMatcher([
   (UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent),
+  (UPat(Ops.SINK, name="sink"), unify_parallel_reduce_ranges),
 ])
 
 def mark_range_mod(ctx:dict[UOp, UOp|None], r:UOp, c:UOp) -> None: