Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions test/backend/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,27 @@ def test_no_requires_grad_works(self):
def test_set_into_requires_grad(self):
z = Tensor.rand(8, 8, requires_grad=True)
x = Tensor.rand(8)
with self.assertRaises(NotImplementedError):
z[:3] = x
z[:3] = x
loss = z.sum()
loss.backward()
np.testing.assert_allclose(z.grad.numpy(), np.ones((8, 8)))

def test_set_with_requires_grad(self):
z = Tensor.rand(8, 8)
x = Tensor.rand(8, requires_grad=True)
with self.assertRaises(NotImplementedError):
z[:3] = x
z[:3] = x
loss = z.sum()
loss.backward()
np.testing.assert_allclose(x.grad.numpy(), np.full(8, 3.0))

def test_set_both_requires_grad(self):
z = Tensor.rand(8, 8, requires_grad=True)
x = Tensor.rand(8, requires_grad=True)
z[:3] = x
loss = z.sum()
loss.backward()
np.testing.assert_allclose(z.grad.numpy(), np.ones((8, 8)))
np.testing.assert_allclose(x.grad.numpy(), np.full(8, 3.0))

class TestSetitemLoop(unittest.TestCase):
def test_arange(self):
Expand Down
3 changes: 0 additions & 3 deletions test/unit/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ def test_index(self):
def delitem(): del reference[0]
self.assertRaises(TypeError, delitem)

# TODO setitem backward
'''
def test_set_item_to_scalar_tensor(self):
m = random.randint(1, 10)
n = random.randint(1, 10)
Expand All @@ -190,7 +188,6 @@ def test_set_item_to_scalar_tensor(self):
z[:, 0] = w
z.sum().backward()
numpy_testing_assert_equal_helper(w.grad, m * a)
'''

def test_step(self):
v = Tensor.arange(10)
Expand Down
4 changes: 3 additions & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,9 @@ def __getitem__(self, indices) -> Tensor:

def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported")
if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): # functional setitem for gradient tracking
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
return self.replace(self._getitem(indices, v))
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem
Expand Down
Loading