diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index b78855ad6e59c..9ceb53c59c584 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -685,21 +685,64 @@ def test_chained_slice_copies(self): def test_swap_slices(self): """Swap two non-overlapping slices - requires reading both before writing.""" - # without .realize() on temps: values not captured before overwriting + # single realize: WAR detection orders contiguous reads before assign writes buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() - left = buf[0:4].contiguous() # lazy - not captured yet - right = buf[4:8].contiguous() # lazy - not captured yet - buf[0:4].assign(right).realize() # this works - buf[4:8].assign(left).realize() # left now reads from modified buf! - np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 5, 6, 7, 8]) # TODO: wrong! should be [5,6,7,8,1,2,3,4] + left = buf[0:4].contiguous() + right = buf[4:8].contiguous() + buf[0:4].assign(right) + buf[4:8].assign(left) + buf.realize() + np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4]) - # with .realize() on temps: values captured before writes + def test_derived_from_contiguous_before_assign(self): + """Tensors derived from a contiguous read should see pre-assign data, same as the contiguous itself.""" buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() - left = buf[0:4].contiguous().realize() - right = buf[4:8].contiguous().realize() - buf[0:4].assign(right).realize() - buf[4:8].assign(left).realize() - np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4]) + left = buf[0:4].contiguous() + derived = left + 10 + right = buf[4:8].contiguous() + buf[0:4].assign(right) + Tensor.realize(buf, left, derived) + np.testing.assert_equal(left.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(derived.numpy(), [11, 12, 13, 14]) + + def test_view_of_contiguous_before_assign(self): + """A view of a contiguous read should see the same pre-assign data as the contiguous.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + c = buf[0:4].contiguous() + v = c[0:2] + right = buf[4:8].contiguous() + buf[0:4].assign(right) + Tensor.realize(buf, c, v) + np.testing.assert_equal(c.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(v.numpy(), [1, 2]) + + def test_gc_contiguous_derived(self): + """GC'd intermediate: no named reference to the contiguous tensor, but its UOp still exists in derived's graph.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + derived = buf[0:4].contiguous() + 10 + right = buf[4:8].contiguous() + buf[0:4].assign(right) + Tensor.realize(buf, derived) + np.testing.assert_equal(derived.numpy(), [11, 12, 13, 14]) + + def test_gc_contiguous_reshaped(self): + """GC'd intermediate: contiguous().reshape() chain where the contiguous tensor is garbage collected.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + v = buf[0:4].contiguous().reshape(2, 2) + right = buf[4:8].contiguous() + buf[0:4].assign(right) + Tensor.realize(buf, v) + np.testing.assert_equal(v.numpy(), [[1, 2], [3, 4]]) + + def test_mixed_contiguous_and_direct(self): + """Tensor reads both through contiguous (snapshot) and directly from buffer.""" + buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize() + c = buf[0:4].contiguous() + t = c + buf[4:8] + buf[0:4].assign(Tensor([9, 10, 11, 12])) + Tensor.realize(buf, c, t) + np.testing.assert_equal(c.numpy(), [1, 2, 3, 4]) + np.testing.assert_equal(t.numpy(), [1+5, 2+6, 3+7, 4+8]) def test_reduction_after_partial_assign(self): """Reduction over buffer after partial assign - must see the assigned values.""" diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fe9f1a8e93c3f..c11fc5bfdb741 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -310,7 +310,23 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor: assign_uop = self.uop.assign(x.uop) base = self.uop.base if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity(): - _apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True) + # CONTIGUOUS is a copy barrier — reads must complete before assign writes. + # pre-compute substituted→original mapping for all CONTIGUOUS UOps reading this buffer + buf = base.buf_uop + orig_map = {base: base.after(assign_uop)} + reverse_map: dict[UOp, UOp] = {} + seen: set[int] = set() + for tref in list(all_tensors): + t = tref() + if t is None: continue + for u in t.uop.backward_slice_with_self: + if u.op is Ops.CONTIGUOUS and id(u) not in seen: + seen.add(id(u)) + if buf in u.backward_slice: + nc = u.substitute(orig_map, walk=True) + if nc is not u: reverse_map[nc] = u + _apply_map_to_tensors(orig_map, name="Embed View Assign", walk=True) + if reverse_map: _apply_map_to_tensors(reverse_map, name="Restore CONTIGUOUS readers", walk=True) return self.replace(self._apply_uop(lambda *_: assign_uop, x)) def detach(self) -> Tensor: