Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions test/unit/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,13 +756,12 @@ def test_interleaved_assign_read_patterns(self):
np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])

def test_variable_slice_ordering(self):
"""Variable-indexed slices - tests symbolic dependency tracking."""
"""Variable-indexed slices with conflicting bindings are not allowed."""
v_i = Variable("i", 0, 3)
buf = Tensor.zeros(4, 4).contiguous().realize()
buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
self.assertEqual(buf[0:1, :].sum().item(), 4)
self.assertEqual(buf[1:2, :].sum().item(), 8)
with self.assertRaises(RuntimeError): buf.sum().realize()

def test_multi_step_assign_read_write_same_buffer(self):
"""Assign to m and param reading b, then update b, across multiple steps.
Expand Down
14 changes: 10 additions & 4 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,27 +268,33 @@ def schedule(self, *lst:Tensor) -> list[ExecItem]:
assert len(var_vals) == 0
return schedule

@disable_gc()
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""
# side-realize pending assigns for buffers referenced by these tensors
# collect and schedule pending assigns, then schedule main computation, run everything once
all_schedules: list[ExecItem] = []
all_var_vals: dict[str, int] = {}
if _pending_assigns:
def _realize_pending(buf):
nonlocal all_var_vals
for assign_uop in _pending_assigns.pop(buf, []):
# recursively realize pending assigns that this assign's value depends on
for u in assign_uop.toposort():
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
all_schedules.extend(schedule)
all_var_vals = merge_dicts([all_var_vals, var_vals])
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
if becomes_map:
for assigns in _pending_assigns.values():
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
if buf in _pending_assigns: _realize_pending(buf)
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
schedule, var_vals = Tensor.schedule_with_vars(*to_realize)
all_schedules.extend(schedule)
all_var_vals = merge_dicts([all_var_vals, var_vals])
if all_schedules: run_schedule(all_schedules, all_var_vals, do_update_stats=do_update_stats)
return self

def replace(self, x:Tensor) -> Tensor:
Expand Down
Loading