Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/backend/test_multitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ def test_shard_empty(self):
assert GlobalCounters.kernel_count == 0
(X + X).realize()

def test_arange_shard_no_copy(self):
    """Sharding a lazy (unrealized) pure computation must not schedule any
    cross-device COPY ops; an already-realized tensor has buffers, so its
    shard legitimately keeps copies."""
    def _has_copy(schedule):
        # True iff any scheduled item is a device-to-device copy
        return any(item.ast.op is Ops.COPY for item in schedule)

    # lazy arange: no copies, and values survive the shard
    assert not _has_copy(Tensor.arange(16).shard(devices_4, axis=0).schedule())
    np.testing.assert_equal(Tensor.arange(16).shard(devices_4, axis=0).numpy(), np.arange(16))

    # lazy elementwise chain on top of arange: still no copies
    assert not _has_copy((Tensor.arange(16)*2+1).shard(devices_4, axis=0).schedule())
    np.testing.assert_equal((Tensor.arange(16)*2+1).shard(devices_4, axis=0).numpy(), np.arange(16)*2+1)

    # realized tensor keeps copies since it has buffers
    assert _has_copy(Tensor.arange(16).realize().shard(devices_4, axis=0).schedule())

def test_arange_shrink(self):
x = Tensor.arange(4)
self.assertEqual(x.shard(devices_2, 0).realize().shrink(((2, 4),)).tolist(), [2, 3])
Expand Down
8 changes: 8 additions & 0 deletions tinygrad/schedule/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:

# ***** multi rewrite MSELECT/MSTACK *****

def copy_pure_to_device(x:UOp, d:UOp):
  """Rewrite a COPY of a pure computation into the same computation on the target device.

  Returns the rewritten UOp, or None when the rewrite does not apply:
  the source must live on a single (string) device of the same device
  family as the target, and its backward slice must be "pure" — no
  BUFFER/PARAM nodes and no DEVICE node other than the source device.
  If source and target devices are identical, x is returned unchanged.
  """
  src_dev = x.device
  # only single-device sources, and only within the same device family (prefix before ':')
  if not isinstance(src_dev, str): return None
  if d.arg.split(":")[0] != src_dev.split(":")[0]: return None
  # any buffer/param, or a reference to a different device, makes the graph impure
  def _impure(u:UOp) -> bool:
    return u.op in {Ops.BUFFER, Ops.PARAM} or (u.op is Ops.DEVICE and u.arg != src_dev)
  if any(_impure(u) for u in x.backward_slice_with_self): return None
  # same device: nothing to remap
  if d.arg == src_dev: return x
  # retarget: swap the source DEVICE node for the destination one
  return x.substitute({UOp(Ops.DEVICE, arg=src_dev): UOp(Ops.DEVICE, arg=d.arg)})

def mstack_early_shrink(ms:UOp, shrink:UOp):
ret:list[UOp] = []
def apply_shrink(s:UOp, i:int) -> UOp:
Expand All @@ -75,6 +81,8 @@ def apply_shrink(s:UOp, i:int) -> UOp:
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
# eliminate COPY for pure (no BUFFER/PARAM, single device) computations by remapping to target device
(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), copy_pure_to_device),
# COPY_TO_ONE: if copying from multidevice to one, MSELECT the first (TODO: a little from each?)
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
Expand Down
Loading