Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions test/unit/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,21 +685,64 @@ def test_chained_slice_copies(self):

def test_swap_slices(self):
  """Swap two non-overlapping slices - requires reading both before writing.

  Both halves are captured lazily via .contiguous() (no explicit .realize()
  on the temporaries); a single realize of buf must still order the
  contiguous reads before the assign writes (WAR hazard detection).
  """
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  left = buf[0:4].contiguous()   # lazy snapshot of the first half
  right = buf[4:8].contiguous()  # lazy snapshot of the second half
  buf[0:4].assign(right)
  buf[4:8].assign(left)
  buf.realize()
  # correct swap: left must have been read before buf[0:4] was overwritten
  np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4])

# with .realize() on temps: values captured before writes
def test_derived_from_contiguous_before_assign(self):
  """Tensors derived from a contiguous read should see pre-assign data, same as the contiguous itself."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  left = buf[0:4].contiguous()  # lazy snapshot of the first half
  derived = left + 10           # computed from the snapshot, not from buf directly
  right = buf[4:8].contiguous()
  buf[0:4].assign(right)
  Tensor.realize(buf, left, derived)
  # both the snapshot and its derivative must reflect the pre-assign values
  np.testing.assert_equal(left.numpy(), [1, 2, 3, 4])
  np.testing.assert_equal(derived.numpy(), [11, 12, 13, 14])

def test_view_of_contiguous_before_assign(self):
  """A view of a contiguous read should see the same pre-assign data as the contiguous."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  # snapshot the first half lazily, then take a narrower view of that snapshot
  snap = buf[0:4].contiguous()
  window = snap[0:2]
  src = buf[4:8].contiguous()
  buf[0:4].assign(src)
  Tensor.realize(buf, snap, window)
  # snapshot and its view both show the values from before the assign
  np.testing.assert_equal(snap.numpy(), [1, 2, 3, 4])
  np.testing.assert_equal(window.numpy(), [1, 2])

def test_gc_contiguous_derived(self):
  """GC'd intermediate: no named reference to the contiguous tensor, but its UOp still exists in derived's graph."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  # the contiguous tensor is never bound to a name, so only the derived
  # tensor's graph keeps its UOp alive
  plus_ten = buf[0:4].contiguous() + 10
  buf[0:4].assign(buf[4:8].contiguous())
  Tensor.realize(buf, plus_ten)
  np.testing.assert_equal(plus_ten.numpy(), [11, 12, 13, 14])

def test_gc_contiguous_reshaped(self):
  """GC'd intermediate: contiguous().reshape() chain where the contiguous tensor is garbage collected."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  # reshape of an unnamed contiguous: the intermediate tensor object is
  # collectable, but the view must still read the pre-assign data
  square = buf[0:4].contiguous().reshape(2, 2)
  buf[0:4].assign(buf[4:8].contiguous())
  Tensor.realize(buf, square)
  np.testing.assert_equal(square.numpy(), [[1, 2], [3, 4]])

def test_mixed_contiguous_and_direct(self):
  """Tensor reads both through contiguous (snapshot) and directly from buffer."""
  buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
  snap = buf[0:4].contiguous()
  # mixed reads: one operand is the snapshot, the other reads buf directly
  mixed = snap + buf[4:8]
  buf[0:4].assign(Tensor([9, 10, 11, 12]))
  Tensor.realize(buf, snap, mixed)
  np.testing.assert_equal(snap.numpy(), [1, 2, 3, 4])
  np.testing.assert_equal(mixed.numpy(), [1+5, 2+6, 3+7, 4+8])

def test_reduction_after_partial_assign(self):
"""Reduction over buffer after partial assign - must see the assigned values."""
Expand Down
18 changes: 17 additions & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,23 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
assign_uop = self.uop.assign(x.uop)
base = self.uop.base
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
# CONTIGUOUS is a copy barrier — reads must complete before assign writes.
# pre-compute substituted→original mapping for all CONTIGUOUS UOps reading this buffer
buf = base.buf_uop
orig_map = {base: base.after(assign_uop)}
reverse_map: dict[UOp, UOp] = {}
seen: set[int] = set()
for tref in list(all_tensors):
t = tref()
if t is None: continue
for u in t.uop.backward_slice_with_self:
if u.op is Ops.CONTIGUOUS and id(u) not in seen:
seen.add(id(u))
if buf in u.backward_slice:
nc = u.substitute(orig_map, walk=True)
if nc is not u: reverse_map[nc] = u
_apply_map_to_tensors(orig_map, name="Embed View Assign", walk=True)
if reverse_map: _apply_map_to_tensors(reverse_map, name="Restore CONTIGUOUS readers", walk=True)
return self.replace(self._apply_uop(lambda *_: assign_uop, x))

def detach(self) -> Tensor:
Expand Down
Loading