From a1951d5f553abb19e464508576bbd3104eb891d4 Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Sun, 22 Feb 2026 20:12:08 -0500
Subject: [PATCH] multi arange compute on each device

A COPY whose source is a pure computation (no BUFFER/PARAM in its
backward slice, single source device, same device family) moves no real
data: the computation can be remapped to run directly on the destination
device instead. This lets sharded arange-style tensors schedule without
any cross-device copies; realized tensors still copy since they have
buffers.
---
 test/backend/test_multitensor.py | 12 ++++++++++++
 tinygrad/schedule/multi.py       |  8 ++++++++
 2 files changed, 20 insertions(+)

diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py
index ebb1bd96608c0..9fedb6f6f1f5d 100644
--- a/test/backend/test_multitensor.py
+++ b/test/backend/test_multitensor.py
@@ -64,6 +64,18 @@ def test_shard_empty(self):
     assert GlobalCounters.kernel_count == 0
     (X + X).realize()
 
+  def test_arange_shard_no_copy(self):
+    # pure computations should have no cross-device copies when sharded
+    sched = Tensor.arange(16).shard(devices_4, axis=0).schedule()
+    assert not any(si.ast.op is Ops.COPY for si in sched)
+    np.testing.assert_equal(Tensor.arange(16).shard(devices_4, axis=0).numpy(), np.arange(16))
+    sched = (Tensor.arange(16)*2+1).shard(devices_4, axis=0).schedule()
+    assert not any(si.ast.op is Ops.COPY for si in sched)
+    np.testing.assert_equal((Tensor.arange(16)*2+1).shard(devices_4, axis=0).numpy(), np.arange(16)*2+1)
+    # a realized tensor keeps copies since it has buffers
+    sched = Tensor.arange(16).realize().shard(devices_4, axis=0).schedule()
+    assert any(si.ast.op is Ops.COPY for si in sched)
+
   def test_arange_shrink(self):
     x = Tensor.arange(4)
     self.assertEqual(x.shard(devices_2, 0).realize().shrink(((2, 4),)).tolist(), [2, 3])
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 077ac47a154a7..4a0ffae22c92d 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -57,6 +57,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
 
 # ***** multi rewrite MSELECT/MSTACK *****
 
+def copy_pure_to_device(x:UOp, d:UOp):
+  if not isinstance(x.device, str) or d.arg.split(":")[0] != x.device.split(":")[0]: return None
+  if any(u.op in {Ops.BUFFER, Ops.PARAM} or (u.op is Ops.DEVICE and u.arg != x.device) for u in x.backward_slice_with_self): return None
+  if d.arg == x.device: return x
+  return x.substitute({UOp(Ops.DEVICE, arg=x.device): UOp(Ops.DEVICE, arg=d.arg)})
+
 def mstack_early_shrink(ms:UOp, shrink:UOp):
   ret:list[UOp] = []
   def apply_shrink(s:UOp, i:int) -> UOp:
@@ -75,6 +81,8 @@ def apply_shrink(s:UOp, i:int) -> UOp:
   # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))),
    lambda c,x: UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
+  # eliminate COPY for pure (no BUFFER/PARAM, single device) computations by remapping to target device
+  (UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), copy_pure_to_device),
   # COPY_TO_ONE: if copying from multidevice to one, MSELECT the first (TODO: a little from each?)
   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))),
    lambda c,x: x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),