From 19960a41ac6012cd37e4ec5833d341df2a92bb83 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Fri, 27 Feb 2026 10:31:52 -0500 Subject: [PATCH] setitem backward --- test/backend/test_setitem.py | 21 +++++++++++++++++---- test/unit/test_indexing.py | 3 --- tinygrad/tensor.py | 4 +++- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py index 4626d2875b171..32d81425cb61c 100644 --- a/test/backend/test_setitem.py +++ b/test/backend/test_setitem.py @@ -282,14 +282,27 @@ def test_no_requires_grad_works(self): def test_set_into_requires_grad(self): z = Tensor.rand(8, 8, requires_grad=True) x = Tensor.rand(8) - with self.assertRaises(NotImplementedError): - z[:3] = x + z[:3] = x + loss = z.sum() + loss.backward() + np.testing.assert_allclose(z.grad.numpy(), np.ones((8, 8))) def test_set_with_requires_grad(self): z = Tensor.rand(8, 8) x = Tensor.rand(8, requires_grad=True) - with self.assertRaises(NotImplementedError): - z[:3] = x + z[:3] = x + loss = z.sum() + loss.backward() + np.testing.assert_allclose(x.grad.numpy(), np.full(8, 3.0)) + + def test_set_both_requires_grad(self): + z = Tensor.rand(8, 8, requires_grad=True) + x = Tensor.rand(8, requires_grad=True) + z[:3] = x + loss = z.sum() + loss.backward() + np.testing.assert_allclose(z.grad.numpy(), np.ones((8, 8))) + np.testing.assert_allclose(x.grad.numpy(), np.full(8, 3.0)) class TestSetitemLoop(unittest.TestCase): def test_arange(self): diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index 31ab4ff124d03..8e195f0f1d26e 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -179,8 +179,6 @@ def test_index(self): def delitem(): del reference[0] self.assertRaises(TypeError, delitem) - # TODO setitem backward - ''' def test_set_item_to_scalar_tensor(self): m = random.randint(1, 10) n = random.randint(1, 10) @@ -190,7 +188,6 @@ def test_set_item_to_scalar_tensor(self): z[:, 0] = w z.sum().backward() numpy_testing_assert_equal_helper(w.grad, m * a) - ''' def test_step(self): v = Tensor.arange(10) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index afb163e225f4c..75c7c8580f0d3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1326,7 +1326,9 @@ def __getitem__(self, indices) -> Tensor: def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}") - if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported") + if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): # functional setitem for gradient tracking + if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) + return self.replace(self._getitem(indices, v)) idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices) is_disk = isinstance(self.device, str) and self.device.startswith("DISK") if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem