diff --git a/test/unit/test_function.py b/test/unit/test_function.py
index 5f0b4b4c1f11f..5774d5bbe14be 100644
--- a/test/unit/test_function.py
+++ b/test/unit/test_function.py
@@ -213,6 +213,17 @@ def f(a:Tensor, b:Tensor) -> Tensor:
     np.testing.assert_equal(a.numpy(), [11,21,31]) # TODO: should be [1,2,3]
     np.testing.assert_equal(b.numpy(), [10,20,30])

+  def test_view_assign_explicit_buffer(self):
+    """view assign on an explicit param's buffer should not create implicit inputs."""
+    class State:
+      def __init__(self): self.buf = Tensor.zeros(2, 4).contiguous().realize()
+      @function(allow_implicit=False)
+      def __call__(self, x:Tensor) -> Tensor:
+        self.buf[:, 0:2].assign(x)
+        return self.buf[:, 0:2]
+    s = State()
+    np.testing.assert_equal(s(Tensor([[5., 6.], [7., 8.]])).numpy(), [[5., 6.], [7., 8.]])
+
   @unittest.expectedFailure
   def test_assign_slice(self):
     @function
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index cc43ef93021e4..704eba4414e8f 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -132,15 +132,9 @@ def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
     k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])

-    # TODO: fix assign to behave like this
-    assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
-    tensor_assigned_kv = Tensor(assigned_kv, device=assigned_kv.device)
-    k = tensor_assigned_kv[0, :, :, 0:start_pos+T, :]
-    v = tensor_assigned_kv[1, :, :, 0:start_pos+T, :]
-
-    #self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
-    #k = self.cache_kv[0, :, :, 0:start_pos+T, :]
-    #v = self.cache_kv[1, :, :, 0:start_pos+T, :]
+    self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
+    k = self.cache_kv[0, :, :, 0:start_pos+T, :]
+    v = self.cache_kv[1, :, :, 0:start_pos+T, :]

     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
     # TODO: this if statement should be removed and it shouldn't generate extra kernels
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 4f76db143328d..051f28ed5a267 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -312,16 +312,15 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
     store_uop = self.uop.store(x.uop)
     base = self.uop.base
     if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
-      # view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
-      original_uop = self.uop
-      view_after = self.uop.after(store_uop)
-      assigned_base = base.after(view_after)
-      _apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
+      # view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(ib, inner) for dependency
+      # replace at the buffer-identity level (e.g. RESHAPE(BUFFER)) so @function's substitution catches it
+      ib = self.uop
+      while not ib.has_buffer_identity() and ib is not base: ib = ib.src[0]
+      assigned_ib = ib.after(self.uop.after(store_uop))
+      _apply_map_to_tensors({ib: assigned_ib}, name="Embed View Assign", walk=True)
       def replace_view_base(u:UOp) -> UOp:
-        return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
-      ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
-      self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
-      return ret
+        return u.replace(src=((assigned_ib if u.src[0] is ib else replace_view_base(u.src[0])),)+u.src[1:])
+      return Tensor(replace_view_base(self.uop), device=self.device, requires_grad=self.requires_grad)
     # simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
     return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))
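For reference, a minimal standalone sketch (not part of the diff) of the view-assign
pattern that the new test and the llm.py KV-cache update rely on. It assumes only
tinygrad's public Tensor API; the cache and new names below are illustrative.

  import numpy as np
  from tinygrad import Tensor

  cache = Tensor.zeros(2, 4).contiguous().realize()  # persistent realized buffer, like cache_kv
  new = Tensor([[5., 6.], [7., 8.]])

  cache[:, 0:2].assign(new)  # view assign: store into a slice of the buffer
  np.testing.assert_equal(cache[:, 0:2].numpy(), [[5., 6.], [7., 8.]])  # slice reads back the stored values
  np.testing.assert_equal(cache[:, 2:4].numpy(), [[0., 0.], [0., 0.]])  # untouched columns stay zero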