Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions test/null/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,12 @@ def test_div_mod_recompose_low_order_remainder(self):
x = Variable("x", 0, 127)
self.helper_test_variable((x//2)%4*2 + x%2, 0, 7, "(x%8)")

# Scaled variant of the div/mod recombination check:
# ((v//div)%d)*(div*mul) + (v%div)*mul is expected to come out as (v%(div*d))*mul.
def test_div_mod_recompose_low_order_remainder_scaled(self):
x = Variable("x", 0, 127)
# div=2, d=4, mul=3 -> x%8*3; the helper is given 0, 21 and "(x%8*3)"
# (presumably expected min, max and rendered form — confirm against helper_test_variable)
self.helper_test_variable((x//2)%4*6 + (x%2)*3, 0, 21, "(x%8*3)")
y = Variable("y", 0, 59)
# div=3, d=4, mul=5 -> y%12*5; non-power-of-two divisor case
self.helper_test_variable((y//3)%4*15 + (y%3)*5, 0, 55, "(y%12*5)")

def test_reshape_index_roundtrip(self):
# simulate reshape index decompose then recompose — the core pattern this enables
# (8,8) decomposed for (16,4): combined=r0*8+r1, div and mod by 4
Expand Down
12 changes: 6 additions & 6 deletions tinygrad/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.symbolic import sym, sym_bpm, propagate_invalid, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
Expand Down Expand Up @@ -45,7 +45,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
if IMAGE == 1 and ren.device in {"QCOM", "CL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)

# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
sink = graph_rewrite(sink, sym+pm_flatten_range, bpm=sym_bpm, name="initial symbolic")

# optimize (schedule) the AST
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
Expand All @@ -54,10 +54,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = apply_opts(sink, ren)

# ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
sink = graph_rewrite(sink, sym+pm_move_where_on_load, bpm=sym_bpm, name="postopt symbolic")

# expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, bpm=sym_bpm, name="expander")

# add locals
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
Expand All @@ -78,11 +78,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, bpm=sym_bpm, name="devectorize")

# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
sink = graph_rewrite(sink, symbolic, bpm=propagate_invalid, name="post index symbolic")

# optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/codegen/simplify.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
from tinygrad.uop.symbolic import symbolic
from tinygrad.uop.symbolic import symbolic, propagate_invalid
from tinygrad.helpers import partition
from tinygrad.dtype import dtypes, ImageDType

Expand Down Expand Up @@ -28,7 +28,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
s0, s1 = r0.src[0], r1.src[0]
# do the merge
new_range = r0.replace(src=(s0*s1,))
nidx = graph_rewrite(u, _substitute+symbolic+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1},
nidx = graph_rewrite(u, _substitute+symbolic+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1}, bpm=propagate_invalid,
name=f"check_merge_{r0.arg[0]}_{r1.arg[0]}")

# check if it simplifies
Expand Down
6 changes: 3 additions & 3 deletions tinygrad/schedule/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.uop.symbolic import symbolic, propagate_invalid, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored

ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Expand Down Expand Up @@ -137,7 +137,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
axes_out.append(combined_axes % s)
combined_axes //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
return graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape")
return graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, bpm=propagate_invalid, name="reshape")

# this is the definition of the movement ops
@functools.cache
Expand Down Expand Up @@ -225,7 +225,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if all_all_same or (PCONTIG and all_same(local_rngs)):
# the new valid is the OR of all the children valids
minimum_valid = UOp.const(dtypes.bool, False).sum(*valids)
_out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
_out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, bpm=propagate_invalid, name="minimum_valid"))
else:
_out_rngs.append(rctx.new_range(x.shape[i]))
_realize_axis.append(i)
Expand Down
5 changes: 3 additions & 2 deletions tinygrad/schedule/rangeify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace, Invalid
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
from tinygrad.uop.symbolic import symbolic
from tinygrad.uop.symbolic import symbolic, propagate_invalid
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
Expand Down Expand Up @@ -548,7 +548,8 @@ def get_kernel_graph(sink:UOp) -> UOp:
# convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))

tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, bpm=propagate_invalid,
name="symbolic+reduce_collapse+debuf")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")

if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Rangeify")
Expand Down
9 changes: 5 additions & 4 deletions tinygrad/uop/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ def ranges(self) -> dict[UOp, None]:
# Rewrite this UOp with the `symbolic` pattern matcher via graph_rewrite (rewrite name "simplify").
# CONST and VCONST ops are returned as-is without running a rewrite.
# tracked: when falsy, TRACK_MATCH_STATS is set to 0 for the duration of the rewrite; otherwise its current value is used.
# NOTE(review): this span comes from a rendered diff with no +/- markers; the two import lines and the two
# return lines look like the before/after versions of one statement each (the later one adds
# propagate_invalid and passes it as bpm=) — confirm against the actual file.
def simplify(self, tracked=False):
if self.op in {Ops.CONST, Ops.VCONST}: return self
# late import!
from tinygrad.uop.symbolic import symbolic
from tinygrad.uop.symbolic import symbolic, propagate_invalid
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
return graph_rewrite(self, symbolic, name="simplify")
return graph_rewrite(self, symbolic, bpm=propagate_invalid, name="simplify")
# Simplify, then unwrap: returns the raw `arg` when the simplified result is a CONST, otherwise the simplified UOp itself.
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
# Returns `self.arg` when this UOp is a CONST, otherwise returns the UOp unchanged (no simplification is run).
def sintify(self) -> sint: return self.arg if self.op is Ops.CONST else self
def _eval(self, dtype, expected_type:Type[T]) -> T:
Expand Down Expand Up @@ -666,8 +666,9 @@ def buf_uop(self) -> UOp:
def contiguous_view_offset(self) -> int|None:
"""If movement ops on a BUFFER collapse to a contiguous range, return `offset` in elements. Otherwise None."""
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.uop.symbolic import symbolic
out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset")
from tinygrad.uop.symbolic import symbolic, propagate_invalid
out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, bpm=propagate_invalid,
name="contiguous_view_offset")
if out.op is not Ops.INDEX: return None
if out.src[1].op is Ops.CONST and self.size == 1:
if not isinstance(out.src[1].arg, int): return None # masked/padded regions produce InvalidType
Expand Down
9 changes: 5 additions & 4 deletions tinygrad/uop/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None:
if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST:
exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div
if exact: return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base*mul)
# ((base//div)%d)*div + base%div -> base%(div*d)
if mul == 1 and div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV:
# ((base//div)%d)*(div*mul) + (base%div)*mul -> (base%(div*d))*mul
if div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV:
if q.src[0].src[0] is base and q.src[0].src[1].op is Ops.CONST and q.src[0].src[1].arg == div:
return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base % (div*d))
return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base % (div*d) * mul)
return None

# this needs to be before symbolic so that 0*something_that_might_be_invalid doesn't become 0
Expand All @@ -69,7 +69,7 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None:
(UPat(Ops.BITCAST, src=(invalid_gate,), name="bc"), lambda bc,cond,x,i: cond.where(x.bitcast(bc.dtype), i.bitcast(bc.dtype))),
])

symbolic_simple = propagate_invalid + PatternMatcher([
symbolic_simple = PatternMatcher([
# ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
Expand Down Expand Up @@ -415,6 +415,7 @@ def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None:

# this is symbolic 2.0
REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK}
# matcher meant to be passed as graph_rewrite(..., bpm=sym_bpm) next to `sym`:
# propagate_invalid plus the gated_given_valid rule applied on invalid_gate
sym_bpm = propagate_invalid+PatternMatcher([(invalid_gate, gated_given_valid)])
sym = symbolic+pm_simplify_valid+PatternMatcher([
# reorder ALU/VECTORIZE
(UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
Expand Down
Loading