Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/backend/test_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,40 @@ def test_tri_complexity(self):
DSET, DDIM = 2048, 32

class TestIndexing(unittest.TestCase):
def test_arange_getitem(self):
  """A full-range arange gather of a realized tensor fuses into one cheap kernel."""
  data = Tensor.rand(64, 32).realize()
  with Context(NOOPT=1):
    GlobalCounters.reset()
    gathered = data[Tensor.arange(64)]
    schedule = gathered.schedule()
    # the gather must not materialize the arange as a separate kernel
    self.assertEqual(len(schedule), 1)
    run_schedule(schedule)
    assert GlobalCounters.global_ops < 64*64, f"too many ops {GlobalCounters.global_ops}"
  # indexing with arange(64) over axis 0 is the identity permutation
  np.testing.assert_allclose(gathered.numpy(), data.numpy())

def test_two_arange_getitems(self):
  """Two independent arange gathers feeding one add still schedule as a single kernel."""
  lhs = Tensor.rand(64, 32).realize()
  rhs = Tensor.rand(64, 32).realize()
  with Context(NOOPT=1):
    GlobalCounters.reset()
    total = lhs[Tensor.arange(64)] + rhs[Tensor.arange(64)]
    schedule = total.schedule()
    # both gathers and the add must fuse into one kernel
    self.assertEqual(len(schedule), 1)
    run_schedule(schedule)
    assert GlobalCounters.global_ops < 64*64, f"too many ops {GlobalCounters.global_ops}"
  # both gathers are identity permutations, so the result is a plain elementwise add
  np.testing.assert_allclose(total.numpy(), lhs.numpy() + rhs.numpy())

def test_arange_arange_getitem(self):
  """Indexing with two parallel aranges extracts the diagonal in a single cheap kernel."""
  square = Tensor.rand(64, 64).realize()
  with Context(NOOPT=1):
    GlobalCounters.reset()
    diag = square[Tensor.arange(64), Tensor.arange(64)]
    schedule = diag.schedule()
    # fancy indexing with two aranges must not split into multiple kernels
    self.assertEqual(len(schedule), 1)
    run_schedule(schedule)
    assert GlobalCounters.global_ops < 64*64, f"too many ops {GlobalCounters.global_ops}"
  # a[i, i] for i in range(64) is exactly the matrix diagonal
  np.testing.assert_allclose(diag.numpy(), np.diag(square.numpy()))

def test_arange_2_reduce(self):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
needle[1337] = 1
Expand Down
35 changes: 35 additions & 0 deletions test/backend/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,41 @@ def test_self_assign_no_empty_kernel(self):
run_schedule(check_schedule(a, 0, filter_sink=False))
self.assertListEqual(a.tolist(), [[1.]*shape[1]]*shape[0])

def test_parallel_reduces_share_ops(self):
  """Three independent reduces over the same dimension should have similar ops to one reduce."""
  X = Tensor.rand(256, 1000).realize()
  Y = Tensor.randint(256, low=0, high=10).realize()
  # baseline: op count of the built-in loss; SPLIT_REDUCEOP=0 keeps each reduce whole
  GlobalCounters.reset()
  with Context(SPLIT_REDUCEOP=0):
    X.sparse_categorical_crossentropy(Y, label_smoothing=0.1).realize()
  ops_builtin = GlobalCounters.global_ops
  # reimplementation of sparse_categorical_crossentropy using an explicit arange gather,
  # so its reduces (y.sum, mean().sum, loss_mask.sum) are structurally independent
  # NOTE(review): the unsqueeze(0) broadcasts loss_mask against the gathered row —
  # presumably intentional to mirror the builtin's masking; verify against Tensor.sparse_categorical_crossentropy
  def scc2(self, Y, ignore_index=-1, label_smoothing=0.0):
    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
    arange = Tensor.arange(log_probs.shape[0], device=self.device)
    y = log_probs[arange, Y.flatten()] * loss_mask.unsqueeze(0)
    smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask.unsqueeze(0)).sum()
    return -((1 - label_smoothing) * y.sum() + smoothing) / loss_mask.sum()
  # measure the reimplementation under the same context
  GlobalCounters.reset()
  with Context(SPLIT_REDUCEOP=0):
    scc2(X, Y, label_smoothing=0.1).realize()
  ops_scc2 = GlobalCounters.global_ops
  # with parallel reduces sharing ranges, the op count should stay within 2x of the builtin
  self.assertLess(ops_scc2 / ops_builtin, 2.0, f"ops ratio too high: {ops_scc2}/{ops_builtin} = {ops_scc2/ops_builtin:.1f}x")

def test_two_parallel_sums(self):
  """Two parallel reduces (sum and mean) over the same axis still compute correct values."""
  t = Tensor.rand(64, 128).realize()
  with Context(SPLIT_REDUCEOP=0):
    combined = t.sum(axis=1) + t.mean(axis=1)
    got = combined.numpy()
  ref = t.numpy()
  np.testing.assert_allclose(got, ref.sum(axis=1) + ref.mean(axis=1), atol=1e-4)

def test_sequential_reduces_stay_separate(self):
  """log_softmax chains dependent reduces (max then sum); unification must not merge them."""
  t = Tensor.rand(32, 50).realize()
  ref = t.numpy()
  with Context(SPLIT_REDUCEOP=0):
    got = t.log_softmax(axis=-1).numpy()
  # reference: numerically-stable log-softmax computed with numpy
  shifted = ref - ref.max(axis=-1, keepdims=True)
  want = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  np.testing.assert_allclose(got, want, atol=1e-5)

class TestLimitBufs(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
def test_limit_bufs_with_var(self):
Expand Down
18 changes: 18 additions & 0 deletions tinygrad/codegen/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,26 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
u = nidx
return u

def unify_parallel_reduce_ranges(sink:UOp) -> UOp|None:
  """Unify ranges of independent parallel reduces with same sizes so GROUPTOP applies consistently.

  Returns the rewritten sink, or None if no substitution was made (pattern matcher convention).
  """
  # group reduces by all their ranges: reduce range sizes (to unify) + output ranges (must match exactly)
  groups: dict[tuple, list[UOp]] = {}
  for u in sink.toposort():
    if u.op is Ops.REDUCE:
      # key = (extents of the reduce ranges in src[1:] — src[0].arg presumably the range size,
      #        identities of the non-reduced ranges the input depends on — these must be shared exactly)
      key = (tuple(r.src[0].arg for r in u.src[1:]), tuple(sorted(id(r) for r in set(u.src[0].ranges) - set(u.src[1:]))))
      groups.setdefault(key, []).append(u)
  range_subs: dict[UOp, UOp] = {}
  for reds in groups.values():
    # skip any group where one reduce is in another's backward slice: those are sequential,
    # not parallel, and sharing their ranges would be incorrect
    # TODO: pick largest independent subset instead of skipping whole group
    if any(r2 in r1.backward_slice for r1 in reds for r2 in reds if r1 is not r2): continue
    # rewrite every later reduce in the group to reuse the first reduce's range UOps
    for red in reds[1:]:
      for r, c in zip(red.src[1:], reds[0].src[1:]):
        if r is not c: range_subs[r] = c
  return sink.substitute(range_subs) if range_subs else None

# range-simplification rewrite rules: merge adjacent ranges on END/REDUCE, and unify the
# ranges of independent parallel reduces at the SINK so they can be grouped together
pm_simplify_ranges = PatternMatcher([
  (UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent),
  (UPat(Ops.SINK, name="sink"), unify_parallel_reduce_ranges),
])

def mark_range_mod(ctx:dict[UOp, UOp|None], r:UOp, c:UOp) -> None:
Expand Down
Loading