Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/backend/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def test_pad_reduce_unsafe_multiview_st(self):
p = P[0]
p = p.pad(((1, 0), ))
p = p.repeat([2])
run_schedule(check_schedule(p, 4)) # TODO: this is high
run_schedule(check_schedule(p, 3))
tiny_ret = p.numpy()

P = np.ones((3, 3), dtype=np.float32)
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def test_no_extra_contiguous_on_setitem_assign_back(self):
idx = Tensor([1,2,5,6], dtype=dtypes.int32)
flat_base[idx] = Tensor([99,99,99,99])
base.assign(flat_base.reshape(4, 4))
sched = check_schedule(base, 6) # TODO: this is high
sched = check_schedule(base, 4) # TODO: this is high
run_schedule(sched)
expected = list(range(16))
for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
Expand Down
2 changes: 1 addition & 1 deletion test/null/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_dedup_assign(self):
b = Tensor.full((4,), 2.).contiguous()
first = a.assign(b)
second = a.assign(b)
check_schedule([first, second], 2) # TODO: 1?
check_schedule([first, second], 3) # TODO: 1?2?

def test_no_dedup_empty(self):
a = Tensor.empty((4,))
Expand Down
10 changes: 0 additions & 10 deletions tinygrad/schedule/rangeify.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
if any(s.op in unsafe and target.base in s.backward_slice_with_self for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)):
return assign.replace(src=(target, src.contiguous()))

def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
  """Rewrite an ASSIGN whose target is itself an ASSIGN so it targets the chain's root.

  `target` is followed through `src[0]` for as long as the node's op is Ops.ASSIGN;
  the first non-ASSIGN node reached becomes the new target of `assign`.
  """
  # when RHS depends on the previous assign result, break with contiguous
  new_src = src.contiguous() if target in src.toposort() else src
  # follow the first source of each nested ASSIGN until a non-ASSIGN node is reached
  chain_root = target
  while chain_root.op is Ops.ASSIGN:
    chain_root = chain_root.src[0]
  return assign.replace(src=(chain_root, new_src))

def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
Expand Down Expand Up @@ -124,9 +117,6 @@ def resolve_call(c:UOp) -> UOp|None:
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
lambda target, src: target.assign(src.bitcast(target.dtype))),

# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),

# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
])
Expand Down
Loading