From 9e99aff88a15bf44b8f45ccc2a6bdaae14e6bc17 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Sat, 21 Feb 2026 20:10:27 -0500 Subject: [PATCH] remove normalize_assign_target_chain no longer needed for correctness --- test/backend/test_schedule.py | 4 ++-- test/null/test_schedule.py | 2 +- tinygrad/schedule/rangeify.py | 10 ---------- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py index 260a2b33d233e..32c8ad3c54cb6 100644 --- a/test/backend/test_schedule.py +++ b/test/backend/test_schedule.py @@ -747,7 +747,7 @@ def test_pad_reduce_unsafe_multiview_st(self): p = P[0] p = p.pad(((1, 0), )) p = p.repeat([2]) - run_schedule(check_schedule(p, 4)) # TODO: this is high + run_schedule(check_schedule(p, 3)) tiny_ret = p.numpy() P = np.ones((3, 3), dtype=np.float32) @@ -1036,7 +1036,7 @@ def test_no_extra_contiguous_on_setitem_assign_back(self): idx = Tensor([1,2,5,6], dtype=dtypes.int32) flat_base[idx] = Tensor([99,99,99,99]) base.assign(flat_base.reshape(4, 4)) - sched = check_schedule(base, 6) # TODO: this is high + sched = check_schedule(base, 4) # TODO: this is high run_schedule(sched) expected = list(range(16)) for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v diff --git a/test/null/test_schedule.py b/test/null/test_schedule.py index fb89f97f7a264..8e8694eeb32bb 100644 --- a/test/null/test_schedule.py +++ b/test/null/test_schedule.py @@ -356,7 +356,7 @@ def test_dedup_assign(self): b = Tensor.full((4,), 2.).contiguous() first = a.assign(b) second = a.assign(b) - check_schedule([first, second], 2) # TODO: 1? + check_schedule([first, second], 3) # TODO: 1?2? def test_no_dedup_empty(self): a = Tensor.empty((4,)) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 5c9b46e4acef0..31e70b228c4af 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -39,13 +39,6 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp): if any(s.op in unsafe and target.base in s.backward_slice_with_self for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)): return assign.replace(src=(target, src.contiguous())) -def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp): - root_target = target - while root_target.op is Ops.ASSIGN: root_target = root_target.src[0] - # when RHS depends on the previous assign result, break with contiguous - if target in src.toposort(): src = src.contiguous() - return assign.replace(src=(root_target, src)) - def split_reduceop(reduce:UOp, x:UOp): if prod(reduce.shape) == 0: return None if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape)) UOp|None: (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))), lambda target, src: target.assign(src.bitcast(target.dtype))), - # if assign target is itself an ASSIGN chain, canonicalize to the original buffer target - (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain), - # make source contiguous if it has hazardous movement ops on the dest buffer (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard), ])