Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions test/unit/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ def f(a:Tensor, b:Tensor) -> Tensor:
np.testing.assert_equal(a.numpy(), [11,21,31]) # TODO: should be [1,2,3]
np.testing.assert_equal(b.numpy(), [10,20,30])

def test_view_assign_explicit_buffer(self):
  """view assign on an explicit param's buffer should not create implicit inputs."""
  class State:
    # NOTE(review): .contiguous().realize() presumably ensures buf is backed by a
    # real device buffer before the traced call mutates it — confirm against
    # tinygrad's @function capture rules
    def __init__(self): self.buf = Tensor.zeros(2, 4).contiguous().realize()
    @function(allow_implicit=False)
    def __call__(self, x:Tensor) -> Tensor:
      # assign into a column slice (a view) of the buffer reached through an
      # explicit parameter (self); with allow_implicit=False this must not be
      # treated as an implicit input to the traced function
      self.buf[:, 0:2].assign(x)
      return self.buf[:, 0:2]
  s = State()
  # the returned view should reflect exactly the assigned values
  np.testing.assert_equal(s(Tensor([[5., 6.], [7., 8.]])).numpy(), [[5., 6.], [7., 8.]])

@unittest.expectedFailure
def test_assign_slice(self):
@function
Expand Down
12 changes: 3 additions & 9 deletions tinygrad/apps/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,9 @@ def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])

# TODO: fix assign to behave like this
assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
tensor_assigned_kv = Tensor(assigned_kv, device=assigned_kv.device)
k = tensor_assigned_kv[0, :, :, 0:start_pos+T, :]
v = tensor_assigned_kv[1, :, :, 0:start_pos+T, :]

#self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
#k = self.cache_kv[0, :, :, 0:start_pos+T, :]
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
v = self.cache_kv[1, :, :, 0:start_pos+T, :]

# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
# TODO: this if statement should be removed and it shouldn't generate extra kernels
Expand Down
17 changes: 8 additions & 9 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,15 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
store_uop = self.uop.store(x.uop)
base = self.uop.base
if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
# view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
original_uop = self.uop
view_after = self.uop.after(store_uop)
assigned_base = base.after(view_after)
_apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
# view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(ib, inner) for dependency
# replace at the buffer-identity level (e.g. RESHAPE(BUFFER)) so @function's substitution catches it
ib = self.uop
while not ib.has_buffer_identity() and ib is not base: ib = ib.src[0]
assigned_ib = ib.after(self.uop.after(store_uop))
_apply_map_to_tensors({ib: assigned_ib}, name="Embed View Assign", walk=True)
def replace_view_base(u:UOp) -> UOp:
return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
return ret
return u.replace(src=((assigned_ib if u.src[0] is ib else replace_view_base(u.src[0])),)+u.src[1:])
return Tensor(replace_view_base(self.uop), device=self.device, requires_grad=self.requires_grad)
# simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))

Expand Down
Loading