diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py index e4d43463a0210..406cd79460ea6 100644 --- a/test/backend/test_setitem.py +++ b/test/backend/test_setitem.py @@ -293,17 +293,43 @@ def test_no_requires_grad_works(self): x = Tensor.rand(8) z[:3] = x - def test_set_into_requires_grad(self): - z = Tensor.rand(8, 8, requires_grad=True) - x = Tensor.rand(8) - with self.assertRaises(NotImplementedError): - z[:3] = x - def test_set_with_requires_grad(self): - z = Tensor.rand(8, 8) - x = Tensor.rand(8, requires_grad=True) - with self.assertRaises(NotImplementedError): - z[:3] = x + z = Tensor.ones(8, 8) + x = Tensor.rand(8, 8, requires_grad=True) + z[:] = x + z.sum().backward() + np.testing.assert_allclose(x.grad.numpy(), np.ones((8, 8))) + + def test_set_nonleaf_requires_grad(self): + x = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True) + z = x * 2 + z[:2] = Tensor([10.0, 20.0]) + z.sum().backward() + np.testing.assert_allclose(x.grad.numpy(), [0, 0, 2, 2]) + + def test_set_overlapping_requires_grad(self): + z = Tensor.zeros(6, requires_grad=True) + x = Tensor.ones(4, requires_grad=True) + y = Tensor.ones(4, requires_grad=True) * 2 + z[:4] = x + z[2:] = y + z.sum().backward() + np.testing.assert_allclose(x.grad.numpy(), [1, 1, 0, 0]) + np.testing.assert_allclose(y.grad.numpy(), np.ones(4)) + + def test_set_iadd_requires_grad(self): + z = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True) + x = Tensor([10.0, 20.0], requires_grad=True) + z[:2] += x + z.sum().backward() + np.testing.assert_allclose(z.grad.numpy(), np.ones(4)) + np.testing.assert_allclose(x.grad.numpy(), np.ones(2)) + + def test_set_used_before_setitem(self): + z = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True) + _ = z.sum() + with self.assertRaises(RuntimeError): + z[:2] = Tensor([0.0, 0.0]) class TestSetitemLoop(unittest.TestCase): def test_arange(self): diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index 31ab4ff124d03..8e195f0f1d26e 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -179,8 +179,6 @@ def test_index(self): def delitem(): del reference[0] self.assertRaises(TypeError, delitem) - # TODO setitem backward - ''' def test_set_item_to_scalar_tensor(self): m = random.randint(1, 10) n = random.randint(1, 10) @@ -190,7 +188,6 @@ def test_set_item_to_scalar_tensor(self): z[:, 0] = w z.sum().backward() numpy_testing_assert_equal_helper(w.grad, m * a) - ''' def test_step(self): v = Tensor.arange(10) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8907add7c4d46..605142ed9de73 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1321,7 +1321,16 @@ def __getitem__(self, indices) -> Tensor: def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}") - if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported") + if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): + # for +=/-=, v's graph references self.uop through the view — exclude those from the stale-use check + v_uop, v_bw = (v.uop, v.uop.backward_slice) if isinstance(v, Tensor) else (None, {}) + if any(self.uop in t.uop.backward_slice for tref in all_tensors + if (t:=tref()) is not None and t is not self and t.uop is not v_uop and t.uop not in v_bw): + raise RuntimeError("can't setitem on a tensor that already has other uses and requires grad") + if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) + if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1]) + self.replace(self._getitem(indices, v)) + return idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices) is_disk = isinstance(self.device, str) and self.device.startswith("DISK") if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem