From 39db1a25556a1f3e7c02ea3c3f58ca46fe4e92df Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Wed, 4 Mar 2026 18:35:32 -0500 Subject: [PATCH] minor div_and_mod_symbolic cleanups --- test/null/test_uop_symbolic.py | 9 +++++++++ tinygrad/uop/divandmod.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 7a2e293bac13b..9a22de68f3af8 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -297,6 +297,9 @@ def test_mod_binary_expression(self): self.helper_test_variable((3+Variable("a",0,1))%4, 0, 3, "((a*-3)+3)") self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)") + def test_div_binary_expression(self): + self.helper_test_variable((3+Variable("a",0,1))//4, 0, 1, "a") + def test_sum_div_const(self): self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 4, 0, 7, "a") @@ -606,6 +609,12 @@ def test_variable_divmod(self): self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)") self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1") + def test_mod_variable_denom_factor_remainder(self): + d = Variable("d", 2, 5) + a = Variable("a", 0, 3) + b = Variable("b", 0, 1) + self.helper_test_variable((d*a+b)%d, 0, 1, "b") + def test_divmod_variable_denom_fold_to_const(self): x = Variable("x", 20, 23) y = Variable("y", 8, 10) diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index ce917e33d611b..d52673285fce5 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -95,7 +95,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: div_and_mod_symbolic = PatternMatcher([ # ** 1. Fast Inline Rules ** ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) - if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d) + if c.vmin>0 and d.vmin>0 and x.vmin>=0 and a.vmin>=0 else None), # (x//c+a)//d -> (x+a*c)//(c*d) (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None), (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None), ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),