diff --git a/test/null/test_simplify_valid_idx.py b/test/null/test_simplify_valid_idx.py index d01949432b95a..a85f4304472d1 100644 --- a/test/null/test_simplify_valid_idx.py +++ b/test/null/test_simplify_valid_idx.py @@ -390,16 +390,16 @@ def test_simplify4(self): # TODO: can this be simplified further? load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8))) - self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+8)%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+8)%64)", "(idx0//2%4)") load = get_load_image_uop(shape, alu9, (((alu8+(alu3*8))%64),(alu3//8))) - self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+16)%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+16)%64)", "(idx0//2%4)") load = get_load_image_uop(shape, alu9, (((alu8+(alu4*8))%64),(alu4//8))) - self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+24)%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+24)%64)", "(idx0//2%4)") load = get_load_image_uop(shape, alu9, (((alu8+(alu5*8))%64),(alu5//8))) - self.check(load, "(idx0<256)", "((((idx0%8)*32)+(idx0//32))%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "((((idx0%8)*32)+(idx0//32))%64)", "(idx0//2%4)") def test_simplify5(self): # openpilot 0.9.7, chunk replacement to simplify @@ -414,7 +414,7 @@ def test_simplify5(self): valid = alu3<640 load = get_load_image_uop(shape, valid, idx) - self.check(load, "(((idx0+(idx1*64))%192)<160)", "((idx0+((idx1//3)*16))+128)", "(((idx0+(idx1*64))%192)//16)") + self.check(load, "(((idx0+(idx1*64))%192)<160)", "((idx0+((idx1//3)*16))+128)", "((idx1%3)*4)") def test_simplify6(self): # from openpilot diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index f07371803e947..488d016b98434 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -548,6 +548,13 @@ def test_nest_div_negative_factor(self): def test_div_into_mod(self): self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)") + def test_mod_div_reorder(self): + # (x % (a*b)) // a -> (x // a) % b, enables div-mod recombine + x = Variable("x", 0, 23) + self.helper_test_variable(x % 6 // 3, 0, 1, "(x//3%2)") + self.helper_test_variable(x % 12 // 4, 0, 2, "(x//4%3)") + self.helper_test_variable(x%12//4*4 + x%4 + x//12*12, 0, 23, "x") + def test_div_neg_cancel(self): self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((idx//4)+1)") self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx+3)//4)") diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index d52673285fce5..8be86b73ec226 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -22,6 +22,10 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: # ** Constant Denominator Rules ** # these rules strictly require y to be a scalar constant > 0 if y.op is Ops.CONST and (c := y.arg) > 0: + # canonicalize_mod_div: (x%(d*k))//d -> (x//d)%k, puts nested div/mod in div-first canonical form for recombine + if d.op is Ops.IDIV and x.op is Ops.MOD and x.src[1].op is Ops.CONST and x.vmin >= 0 and x.src[1].arg % c == 0: + return x.src[0] // y % x.ufix(x.src[1].arg // c) + # remove_nested_mod: remove nested mod in case the inner mod is a multiple of the outer mod, example: (a%4 + b)%2 -> (a+b)%2 if d.op is Ops.MOD and x.vmin >= 0: new_xs, changed = [], False