Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 76 additions & 12 deletions test/unit/test_assign.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
from tinygrad import dtypes, Device, Tensor, TinyJit, GlobalCounters, Variable
from tinygrad.uop.ops import Ops
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import temp, CI, CPU_LVP, Context
Expand Down Expand Up @@ -685,22 +685,86 @@ def test_chained_slice_copies(self):

def test_swap_slices(self):
  """Swap two non-overlapping slices of one buffer using a single realize.

  Both halves are cloned before either assign is issued, and only one
  realize is performed at the end. Each clone must therefore read the
  buffer as it was before the writes, giving a true swap.
  """
  # single realize: clone reads must complete before assign writes
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  left = buf[0:4].clone()
  right = buf[4:8].clone()
  # neither assign is realized on its own; both are flushed by buf.realize()
  buf[0:4].assign(right)
  buf[4:8].assign(left)
  buf.realize()
  np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4])

# with .realize() on temps: values captured before writes
def test_derived_from_clone_before_assign(self):
  """A clone taken before an assign, and a tensor computed from it, must both reflect pre-assign data."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  snapshot = buf[0:4].clone()
  shifted = snapshot + 10
  upper = buf[4:8].clone()
  # overwrite the region the clone reads from, before the single realize below
  buf[0:4].assign(upper)
  Tensor.realize(buf, snapshot, shifted)
  before_assign = [1, 2, 3, 4]
  np.testing.assert_equal(snapshot.numpy(), before_assign)
  np.testing.assert_equal(shifted.numpy(), [x + 10 for x in before_assign])

def test_view_of_clone_before_assign(self):
  """Slicing a clone must expose the same pre-assign values that the clone itself holds."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  snapshot = buf[0:4].clone()
  head = snapshot[0:2]
  upper = buf[4:8].clone()
  # the assign targets the slice the clone was taken from
  buf[0:4].assign(upper)
  Tensor.realize(buf, snapshot, head)
  before_assign = [1, 2, 3, 4]
  np.testing.assert_equal(snapshot.numpy(), before_assign)
  np.testing.assert_equal(head.numpy(), before_assign[0:2])

def test_gc_clone_derived(self):
  """Clone with no name bound to it: only the tensor derived from it is kept, yet it must see pre-assign data."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  # deliberately a single expression so the intermediate clone tensor is never named
  shifted = buf[0:4].clone() + 10
  upper = buf[4:8].clone()
  buf[0:4].assign(upper)
  Tensor.realize(buf, shifted)
  want = [11, 12, 13, 14]
  np.testing.assert_equal(shifted.numpy(), want)

def test_gc_clone_reshaped(self):
  """GC'd intermediate: clone().reshape() chain where no name is bound to the clone tensor.

  The reshaped result is created before the assign, so it must hold the
  values the first half of the buffer had before being overwritten.
  """
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  # single expression on purpose: the clone itself is never kept in a variable
  v = buf[0:4].clone().reshape(2, 2)
  right = buf[4:8].clone()
  buf[0:4].assign(right)
  Tensor.realize(buf, v)
  np.testing.assert_equal(v.numpy(), [[1, 2], [3, 4]])

def test_mixed_clone_and_direct(self):
  """One tensor combines a clone of the first half with a direct read of the second half of the buffer."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  snapshot = buf[0:4].clone()
  total = snapshot + buf[4:8]
  # only the first half is overwritten; the second half, read directly above, is not assigned to
  buf[0:4].assign(Tensor([9, 10, 11, 12]))
  Tensor.realize(buf, snapshot, total)
  first_half, second_half = [1, 2, 3, 4], [5, 6, 7, 8]
  np.testing.assert_equal(snapshot.numpy(), first_half)
  np.testing.assert_equal(total.numpy(), [a + b for a, b in zip(first_half, second_half)])

def test_copy_swap_slices(self):
  """Swap the two halves of a buffer by round-tripping each half through a second device."""
  dev0 = Device.DEFAULT
  # rsplit leaves a colon-free name intact, so this is "<name>:1" whether or not DEFAULT carries an index
  dev1 = Device.DEFAULT.rsplit(":", 1)[0] + ":1"
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  left = buf[0:4].to(dev1)
  right = buf[4:8].to(dev1)
  # copy each half back to the original device and write it over the opposite half
  buf[0:4].assign(right.to(dev0))
  buf[4:8].assign(left.to(dev0))
  buf.realize()
  swapped = [5, 6, 7, 8, 1, 2, 3, 4]
  np.testing.assert_equal(buf.numpy(), swapped)

def test_copy_before_assign(self):
  """A copy to a second device made before an assign, and a tensor computed from it, must hold pre-assign data."""
  # rsplit leaves a colon-free name intact, so this is "<name>:1" whether or not DEFAULT carries an index
  dev1 = Device.DEFAULT.rsplit(":", 1)[0] + ":1"
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  remote = buf[0:4].to(dev1)
  shifted = remote + 10
  # overwrite the slice that was copied, before anything above is realized explicitly
  buf[0:4].assign(Tensor([9, 10, 11, 12]))
  Tensor.realize(buf, remote, shifted)
  before_assign = [1, 2, 3, 4]
  np.testing.assert_equal(remote.numpy(), before_assign)
  np.testing.assert_equal(shifted.numpy(), [x + 10 for x in before_assign])

def test_reduction_after_partial_assign(self):
"""Reduction over buffer after partial assign - must see the assigned values."""
buf = Tensor.zeros(4, 4).contiguous().realize()
Expand Down
12 changes: 11 additions & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,17 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
assign_uop = self.uop.assign(x.uop)
base = self.uop.base
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
# copy barriers (clone/copy to a different buffer) must read pre-assign data
buf, orig_map = base.buf_uop, {base: base.after(assign_uop)}
reverse_map: dict[UOp, UOp] = {}
for tref in list(all_tensors):
if (t:=tref()) is None: continue
for u in t.uop.backward_slice_with_self:
is_barrier = (u.op is Ops.ASSIGN and u.src[0].buf_uop is not buf) or u.op is Ops.COPY
if is_barrier and buf in u.backward_slice:
if (nc:=u.substitute(orig_map, walk=True)) is not u: reverse_map[nc] = u
_apply_map_to_tensors(orig_map, name="Embed View Assign", walk=True)
if reverse_map: _apply_map_to_tensors(reverse_map, name="Restore copy barrier readers", walk=True)
return self.replace(self._apply_uop(lambda *_: assign_uop, x))

def detach(self) -> Tensor:
Expand Down
Loading