Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions test/backend/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,17 +293,43 @@ def test_no_requires_grad_works(self):
x = Tensor.rand(8)
z[:3] = x

def test_set_into_requires_grad(self):
z = Tensor.rand(8, 8, requires_grad=True)
x = Tensor.rand(8)
with self.assertRaises(NotImplementedError):
z[:3] = x

def test_set_with_requires_grad(self):
z = Tensor.rand(8, 8)
x = Tensor.rand(8, requires_grad=True)
with self.assertRaises(NotImplementedError):
z[:3] = x
z = Tensor.ones(8, 8)
x = Tensor.rand(8, 8, requires_grad=True)
z[:] = x
z.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), np.ones((8, 8)))

def test_set_nonleaf_requires_grad(self):
x = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
z = x * 2
z[:2] = Tensor([10.0, 20.0])
z.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), [0, 0, 2, 2])

def test_set_overlapping_requires_grad(self):
z = Tensor.zeros(6, requires_grad=True)
x = Tensor.ones(4, requires_grad=True)
y = Tensor.ones(4, requires_grad=True) * 2
z[:4] = x
z[2:] = y
z.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), [1, 1, 0, 0])
np.testing.assert_allclose(y.grad.numpy(), np.ones(4))

def test_set_iadd_requires_grad(self):
z = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
x = Tensor([10.0, 20.0], requires_grad=True)
z[:2] += x
z.sum().backward()
np.testing.assert_allclose(z.grad.numpy(), np.ones(4))
np.testing.assert_allclose(x.grad.numpy(), np.ones(2))

def test_set_used_before_setitem(self):
z = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
_ = z.sum()
with self.assertRaises(RuntimeError):
z[:2] = Tensor([0.0, 0.0])

class TestSetitemLoop(unittest.TestCase):
def test_arange(self):
Expand Down
3 changes: 0 additions & 3 deletions test/unit/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ def test_index(self):
def delitem(): del reference[0]
self.assertRaises(TypeError, delitem)

# TODO setitem backward
'''
def test_set_item_to_scalar_tensor(self):
m = random.randint(1, 10)
n = random.randint(1, 10)
Expand All @@ -190,7 +188,6 @@ def test_set_item_to_scalar_tensor(self):
z[:, 0] = w
z.sum().backward()
numpy_testing_assert_equal_helper(w.grad, m * a)
'''

def test_step(self):
v = Tensor.arange(10)
Expand Down
11 changes: 10 additions & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,16 @@ def __getitem__(self, indices) -> Tensor:

def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported")
if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad):
# for +=/-=, v's graph references self.uop through the view — exclude those from the stale-use check
v_uop, v_bw = (v.uop, v.uop.backward_slice) if isinstance(v, Tensor) else (None, {})
if any(self.uop in t.uop.backward_slice for tref in all_tensors
if (t:=tref()) is not None and t is not self and t.uop is not v_uop and t.uop not in v_bw):
raise RuntimeError("can't setitem on a tensor that already has other uses and requires grad")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
self.replace(self._getitem(indices, v))
return
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem
Expand Down
Loading