From f33481de66dda36562c8a85b3c6ba40af141115a Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Wed, 11 Mar 2026 13:21:13 +0800 Subject: [PATCH] bpm invalid [pr] --- test/null/test_uop_symbolic.py | 6 ++++++ tinygrad/codegen/__init__.py | 12 ++++++------ tinygrad/codegen/simplify.py | 4 ++-- tinygrad/schedule/indexing.py | 6 +++--- tinygrad/schedule/rangeify.py | 5 +++-- tinygrad/uop/ops.py | 9 +++++---- tinygrad/uop/symbolic.py | 9 +++++---- 7 files changed, 30 insertions(+), 21 deletions(-) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 423a88ba14b06..e956b557af6cd 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -797,6 +797,12 @@ def test_div_mod_recompose_low_order_remainder(self): x = Variable("x", 0, 127) self.helper_test_variable((x//2)%4*2 + x%2, 0, 7, "(x%8)") + def test_div_mod_recompose_low_order_remainder_scaled(self): + x = Variable("x", 0, 127) + self.helper_test_variable((x//2)%4*6 + (x%2)*3, 0, 21, "(x%8*3)") + y = Variable("y", 0, 59) + self.helper_test_variable((y//3)%4*15 + (y%3)*5, 0, 55, "(y%12*5)") + def test_reshape_index_roundtrip(self): # simulate reshape index decompose then recompose — the core pattern this enables # (8,8) decomposed for (16,4): combined=r0*8+r1, div and mod by 4 diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 78166a2d4680b..d78dc35e133b4 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -12,7 +12,7 @@ # import all pattern matchers here from tinygrad.codegen.gpudims import pm_add_gpudims -from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load +from tinygrad.uop.symbolic import sym, sym_bpm, propagate_invalid, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ @@ -45,7 +45,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - if IMAGE == 1 and ren.device in {"QCOM", "CL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True) # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct) - sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic") + sink = graph_rewrite(sink, sym+pm_flatten_range, bpm=sym_bpm, name="initial symbolic") # optimize (schedule) the AST sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges") @@ -54,10 +54,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - sink = apply_opts(sink, ren) # ** expander (expand_rewrite) ** - sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic") + sink = graph_rewrite(sink, sym+pm_move_where_on_load, bpm=sym_bpm, name="postopt symbolic") # expand - sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") + sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, bpm=sym_bpm, name="expander") # add locals sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers") @@ -78,11 +78,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing - if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize") + if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, bpm=sym_bpm, name="devectorize") # lower the index dtype to a concrete int sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes") - sink = graph_rewrite(sink, symbolic, name="post index symbolic") + sink = graph_rewrite(sink, symbolic, bpm=propagate_invalid, name="post index symbolic") # optional pre matcher if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher") diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index 1d464b2055779..48b92758f5a19 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -1,6 +1,6 @@ import itertools from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start -from tinygrad.uop.symbolic import symbolic +from tinygrad.uop.symbolic import symbolic, propagate_invalid from tinygrad.helpers import partition from tinygrad.dtype import dtypes, ImageDType @@ -28,7 +28,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None: s0, s1 = r0.src[0], r1.src[0] # do the merge new_range = r0.replace(src=(s0*s1,)) - nidx = graph_rewrite(u, _substitute+symbolic+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1}, + nidx = graph_rewrite(u, _substitute+symbolic+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1}, bpm=propagate_invalid, name=f"check_merge_{r0.arg[0]}_{r1.arg[0]}") # check if it simplifies diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 323cbfffabac1..bc78792a31510 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink -from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses +from tinygrad.uop.symbolic import symbolic, propagate_invalid, pm_simplify_valid, pm_drop_and_clauses from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, @@ -137,7 +137,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U axes_out.append(combined_axes % s) combined_axes //= s # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code - return graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape") + return graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, bpm=propagate_invalid, name="reshape") # this is the definition of the movement ops @functools.cache @@ -225,7 +225,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: if all_all_same or (PCONTIG and all_same(local_rngs)): # the new valid is the OR of all the children valids minimum_valid = UOp.const(dtypes.bool, False).sum(*valids) - _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid")) + _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, bpm=propagate_invalid, name="minimum_valid")) else: _out_rngs.append(rctx.new_range(x.shape[i])) _realize_axis.append(i) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 9be14b45739b5..cce4516a13f3f 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -3,7 +3,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace, Invalid from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element -from tinygrad.uop.symbolic import symbolic +from tinygrad.uop.symbolic import symbolic, propagate_invalid from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify @@ -548,7 +548,8 @@ def get_kernel_graph(sink:UOp) -> UOp: # convert movement ops to ranges tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY)) - tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf") + tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, bpm=propagate_invalid, + name="symbolic+reduce_collapse+debuf") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Rangeify") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 23230d0abbfb9..581ddc346eeb7 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -369,9 +369,9 @@ def ranges(self) -> dict[UOp, None]: def simplify(self, tracked=False): if self.op in {Ops.CONST, Ops.VCONST}: return self # late import! - from tinygrad.uop.symbolic import symbolic + from tinygrad.uop.symbolic import symbolic, propagate_invalid with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value): - return graph_rewrite(self, symbolic, name="simplify") + return graph_rewrite(self, symbolic, bpm=propagate_invalid, name="simplify") def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret def sintify(self) -> sint: return self.arg if self.op is Ops.CONST else self def _eval(self, dtype, expected_type:Type[T]) -> T: @@ -666,8 +666,9 @@ def buf_uop(self) -> UOp: def contiguous_view_offset(self) -> int|None: """If movement ops on a BUFFER collapse to a contiguous range, return `offset` in elements. Otherwise None.""" from tinygrad.schedule.rangeify import pm_mops - from tinygrad.uop.symbolic import symbolic - out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset") + from tinygrad.uop.symbolic import symbolic, propagate_invalid + out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, bpm=propagate_invalid, + name="contiguous_view_offset") if out.op is not Ops.INDEX: return None if out.src[1].op is Ops.CONST and self.size == 1: if not isinstance(out.src[1].arg, int): return None # masked/padded regions produce InvalidType diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 092b6b54026a1..cf88f0c9510df 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -42,10 +42,10 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None: if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST: exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div if exact: return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base*mul) - # ((base//div)%d)*div + base%div -> base%(div*d) - if mul == 1 and div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV: + # ((base//div)%d)*(div*mul) + (base%div)*mul -> (base%(div*d))*mul + if div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV: if q.src[0].src[0] is base and q.src[0].src[1].op is Ops.CONST and q.src[0].src[1].arg == div: - return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base % (div*d)) + return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base % (div*d) * mul) return None # this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0 @@ -69,7 +69,7 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None: (UPat(Ops.BITCAST, src=(invalid_gate,), name="bc"), lambda bc,cond,x,i: cond.where(x.bitcast(bc.dtype), i.bitcast(bc.dtype))), ]) -symbolic_simple = propagate_invalid + PatternMatcher([ +symbolic_simple = PatternMatcher([ # ** self folding ** (UPat.var("x") + 0, lambda x: x), # x+0 -> x (UPat.var("x") * 1, lambda x: x), # x*1 -> x @@ -415,6 +415,7 @@ def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None: # this is symbolic 2.0 REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK} +sym_bpm = propagate_invalid+PatternMatcher([(invalid_gate, gated_given_valid)]) sym = symbolic+pm_simplify_valid+PatternMatcher([ # reorder ALU/VECTORIZE (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),