From 45681265bc1b37082f0b13bbcbae214003647650 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Tue, 3 Mar 2026 11:08:05 -0500 Subject: [PATCH] fix assign copy barrier --- test/unit/test_assign.py | 88 ++++++++++++++++++++++++++++++++++------ tinygrad/tensor.py | 12 +++++- 2 files changed, 87 insertions(+), 13 deletions(-) diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index a6609895a2340..e68f380130611 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable +from tinygrad import dtypes, Device, Tensor, TinyJit, GlobalCounters, Variable from tinygrad.uop.ops import Ops from tinygrad.device import is_dtype_supported from tinygrad.helpers import temp, CI, CPU_LVP, Context @@ -685,22 +685,86 @@ def test_chained_slice_copies(self): def test_swap_slices(self): """Swap two non-overlapping slices - requires reading both before writing.""" - # without .realize() on temps: values not captured before overwriting + # single realize: clone reads must complete before assign writes buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() - left = buf[0:4].clone() # lazy - not captured yet - right = buf[4:8].clone() # lazy - not captured yet - buf[0:4].assign(right).realize() # this works - buf[4:8].assign(left).realize() # left now reads from modified buf! - np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 5, 6, 7, 8]) # TODO: wrong! should be [5,6,7,8,1,2,3,4] + left = buf[0:4].clone() + right = buf[4:8].clone() + buf[0:4].assign(right) + buf[4:8].assign(left) + buf.realize() + np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4]) - # with .realize() on temps: values captured before writes + def test_derived_from_clone_before_assign(self): + """Tensors derived from a clone should see pre-assign data, same as the clone itself.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + left = buf[0:4].clone() + derived = left + 10 + right = buf[4:8].clone() + buf[0:4].assign(right) + Tensor.realize(buf, left, derived) + np.testing.assert_equal(left.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(derived.numpy(), [11, 12, 13, 14]) + + def test_view_of_clone_before_assign(self): + """A view of a clone should see the same pre-assign data as the clone.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + c = buf[0:4].clone() + v = c[0:2] + right = buf[4:8].clone() + buf[0:4].assign(right) + Tensor.realize(buf, c, v) + np.testing.assert_equal(c.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(v.numpy(), [1, 2]) + + def test_gc_clone_derived(self): + """GC'd intermediate: no named reference to the clone tensor, but its UOp still exists in derived's graph.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + derived = buf[0:4].clone() + 10 + right = buf[4:8].clone() + buf[0:4].assign(right) + Tensor.realize(buf, derived) + np.testing.assert_equal(derived.numpy(), [11, 12, 13, 14]) + + def test_gc_clone_reshaped(self): + """GC'd intermediate: clone().reshape() chain where the clone tensor is garbage collected.""" buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() - left = buf[0:4].clone().realize() - right = buf[4:8].clone().realize() - buf[0:4].assign(right).realize() - buf[4:8].assign(left).realize() + v = buf[0:4].clone().reshape(2, 2) + right = buf[4:8].clone() + buf[0:4].assign(right) + Tensor.realize(buf, v) + np.testing.assert_equal(v.numpy(), [[1, 2], [3, 4]]) + + def test_mixed_clone_and_direct(self): + """Tensor reads both through clone (snapshot) and directly from buffer.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + c = buf[0:4].clone() + t = c + buf[4:8] + buf[0:4].assign(Tensor([9, 10, 11, 12])) + Tensor.realize(buf, c, t) + np.testing.assert_equal(c.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(t.numpy(), [1+5, 2+6, 3+7, 4+8]) + + def test_copy_swap_slices(self): + """Swap via cross-device copy — copy to another device is a copy barrier like clone.""" + dev0, dev1 = Device.DEFAULT, f"{Device.DEFAULT}:1" if ":" not in Device.DEFAULT else Device.DEFAULT.rsplit(":",1)[0]+":1" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + left, right = buf[0:4].to(dev1), buf[4:8].to(dev1) + buf[0:4].assign(right.to(dev0)) + buf[4:8].assign(left.to(dev0)) + buf.realize() np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4]) + def test_copy_before_assign(self): + """Copy to another device before assign should see pre-assign data.""" + dev1 = f"{Device.DEFAULT}:1" if ":" not in Device.DEFAULT else Device.DEFAULT.rsplit(":",1)[0]+":1" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + c = buf[0:4].to(dev1) + derived = c + 10 + buf[0:4].assign(Tensor([9, 10, 11, 12])) + Tensor.realize(buf, c, derived) + np.testing.assert_equal(c.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(derived.numpy(), [11, 12, 13, 14]) + def test_reduction_after_partial_assign(self): """Reduction over buffer after partial assign - must see the assigned values.""" buf = Tensor.zeros(4, 4).contiguous().realize() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fe9f1a8e93c3f..457ae6ed5913c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -310,7 +310,17 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor: assign_uop = self.uop.assign(x.uop) base = self.uop.base if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity(): - _apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True) + # copy barriers (clone/copy to a different buffer) must read pre-assign data + buf, orig_map = base.buf_uop, {base: base.after(assign_uop)} + reverse_map: dict[UOp, UOp] = {} + for tref in list(all_tensors): + if (t:=tref()) is None: continue + for u in t.uop.backward_slice_with_self: + is_barrier = (u.op is Ops.ASSIGN and u.src[0].buf_uop is not buf) or u.op is Ops.COPY + if is_barrier and buf in u.backward_slice: + if (nc:=u.substitute(orig_map, walk=True)) is not u: reverse_map[nc] = u + _apply_map_to_tensors(orig_map, name="Embed View Assign", walk=True) + if reverse_map: _apply_map_to_tensors(reverse_map, name="Restore copy barrier readers", walk=True) return self.replace(self._apply_uop(lambda *_: assign_uop, x)) def detach(self) -> Tensor: