From 3707037bd0e865213a05f0d52d8da4413b741132 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Fri, 7 Feb 2025 14:39:48 -0500 Subject: [PATCH 01/24] Fuzz reorder statements with basic constraint checks --- src/exo/rewrite/LoopIR_scheduling.py | 10 +- src/exo/rewrite/chexo.py | 273 +++++++++++++ src/exo/rewrite/constraint_solver.py | 359 ++++++++++++++++++ .../test_make_constraint.txt | 1 + .../test_constraint_solver/test_solve.txt | 1 + tests/test_constraint_solver.py | 36 ++ 6 files changed, 676 insertions(+), 4 deletions(-) create mode 100644 src/exo/rewrite/chexo.py create mode 100644 src/exo/rewrite/constraint_solver.py create mode 100644 tests/golden/test_constraint_solver/test_make_constraint.txt create mode 100644 tests/golden/test_constraint_solver/test_solve.txt create mode 100644 tests/test_constraint_solver.py diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index c67f0ec70..ec3361e18 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -32,6 +32,7 @@ Check_ExprBound, Check_Aliasing, ) +from .chexo import fuzz_reorder_stmts from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis @@ -206,7 +207,7 @@ def _replace_pats(ir, fwd, c, pat, repl, only_replace_attrs=True, use_sym_id=Tru todos.append((rd, c_repl)) cur_fwd = lambda x: x - for (rd, c_repl) in todos: + for rd, c_repl in todos: rd = cur_fwd(rd) ir, fwd_rd = _replace_helper(rd, c_repl, only_replace_attrs) cur_fwd = _compose(fwd_rd, cur_fwd) @@ -222,7 +223,7 @@ def _replace_reads(ir, fwd, c, sym, repl, only_replace_attrs=True): todos.append((rd, c_repl)) cur_fwd = lambda x: x - for (rd, c_repl) in todos: + for rd, c_repl in todos: rd = cur_fwd(rd) ir, fwd_rd = _replace_helper(rd, c_repl, only_replace_attrs) cur_fwd = _compose(fwd_rd, cur_fwd) @@ -249,7 +250,7 @@ def _replace_writes( todos.append((s, c_repl)) cur_fwd = lambda x: x - for (s, c_repl) in todos: + for s, c_repl in todos: s = cur_fwd(s) ir, fwd_s = _replace_helper(s, c_repl, only_replace_attrs) cur_fwd = _compose(fwd_s, cur_fwd) @@ -375,7 +376,8 @@ def DoReorderStmt(f_cursor, s_cursor): raise SchedulingError( "expected the second statement to be directly after the first" ) - Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node) + # Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node) + fuzz_reorder_stmts(f_cursor, s_cursor) ir, fwd = s_cursor._move(f_cursor.before()) return ir, fwd diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py new file mode 100644 index 000000000..171d6eda6 --- /dev/null +++ b/src/exo/rewrite/chexo.py @@ -0,0 +1,273 @@ +from typing import Optional +from ..core.LoopIR import LoopIR, T +from dataclasses import dataclass +from ..core.prelude import Sym, SrcInfo +from ..core.memory import DRAM +from ..backend.LoopIR_interpreter import run_interpreter +import numpy as np +from .new_eff import SchedulingError +from .constraint_solver import ( + ConjunctionConstraint, + ConstraintTerm, + GenericConstraint, + Constraint, + ConstraintMaker, +) + + +class LoopIRVisitor: + def visit(self, node): + self.visit_generic(node) + + def visit_generic(self, node): + if ( + isinstance(node, LoopIR.proc) + or isinstance(node, LoopIR.instr) + or isinstance(node, LoopIR.fnarg) + or isinstance(node, LoopIR.stmt) + or isinstance(node, LoopIR.loop_mode) + or isinstance(node, LoopIR.expr) + or isinstance(node, LoopIR.w_access) + or isinstance(node, LoopIR.type) + ): + for field_name in dir(node): + if not 
field_name.startswith("_"): + field = getattr(node, field_name) + if isinstance(field, list): + for child in field: + self.visit(child) + else: + self.visit(field) + + +@dataclass +class TypeVisitor(LoopIRVisitor): + type_map: dict[Sym, LoopIR.type] + + def visit(self, node): + if isinstance(node, LoopIR.For): + self.type_map[node.iter] = T.Index() + self.visit_generic(node) + elif isinstance(node, LoopIR.Alloc): + self.type_map[node.name] = node.type + elif isinstance(node, LoopIR.WindowStmt): + self.type_map[node.name] = node.rhs.type + elif isinstance(node, LoopIR.fnarg): + self.type_map[node.name] = node.type + else: + self.visit_generic(node) + + +@dataclass +class UsedVariableVisitor(LoopIRVisitor): + used_vars: set[Sym] + + def visit(self, node): + if isinstance(node, Sym): + self.used_vars.add(node) + else: + self.visit_generic(node) + + +def get_free_variables(type_map, fragment): + fragment_type_visitor = TypeVisitor({}) + fragment_var_visitor = UsedVariableVisitor(set()) + for stmt in fragment: + fragment_type_visitor.visit(stmt) + fragment_var_visitor.visit(stmt) + for var in fragment_var_visitor.used_vars - fragment_type_visitor.type_map.keys(): + fragment_var_visitor.visit(type_map[var]) + return { + var: type_map[var] + for var in fragment_var_visitor.used_vars + - fragment_type_visitor.type_map.keys() + } + + +def eval_tensor_dimension(dim_expr, control_values): + if isinstance(dim_expr, LoopIR.Read): + return control_values[dim_expr.name] + elif isinstance(dim_expr, LoopIR.Const): + return dim_expr.val + elif isinstance(dim_expr, LoopIR.USub): + return -eval_tensor_dimension(dim_expr.arg) + elif isinstance(dim_expr, LoopIR.BinOp): + lhs, rhs = eval_tensor_dimension(dim_expr.lhs), eval_tensor_dimension( + dim_expr.rhs + ) + if dim_expr.op == "+": + return lhs + rhs + elif dim_expr.op == "-": + return lhs - rhs + elif dim_expr.op == "*": + return lhs * rhs + elif dim_expr.op == "/": + if isinstance(lhs, int) and isinstance(rhs, int): + # this is what was here before and without the rhs check + # counter example of why this is wrong -3 / 2 == -1 in C and 0 in this impl + # return (lhs + rhs - 1) // rhs + return int(lhs / rhs) + else: + return lhs / rhs + elif dim_expr.op == "%": + return lhs % rhs + elif dim_expr.op == "==": + return lhs == rhs + elif dim_expr.op == "<": + return lhs < rhs + elif dim_expr.op == ">": + return lhs > rhs + elif dim_expr.op == "<=": + return lhs <= rhs + elif dim_expr.op == ">=": + return lhs >= rhs + elif dim_expr.op == "and": + return lhs and rhs + elif dim_expr.op == "or": + return lhs or rhs + else: + assert False, "unexpected expression type in tensor dimension" + + +CONTROL_VAL_BOUND = 16 +INT_BOUND = 128 +FLOAT_BOUND = 32 + + +def collect_path_constraints(cursor, cm: ConstraintMaker) -> GenericConstraint: + cur = cursor + result = Constraint(()) + while cur.depth() != 0: + if isinstance(cur._node, LoopIR.For): + result = ConjunctionConstraint( + ConjunctionConstraint( + result, + Constraint( + (ConstraintTerm(1, (cur._node.iter,)),) + + tuple( + term.negate() + for term in cm.make_constraint_terms(cur._node.lo) + ) + ), + ), + Constraint( + (ConstraintTerm(-1, (cur._node.iter,)),) + + cm.make_constraint_terms(cur._node.hi) + + (ConstraintTerm(-1, ()),) + ), + ) + elif isinstance(cur._node, LoopIR.If): + result = ConjunctionConstraint(result, cm.make_constraint(cur._node.cond)) + cur = cur.parent() + return result + + +def generate_args(args, constraint: Constraint, cm: ConstraintMaker): + arg_values = {} + control_values = {} + assignments 
= cm.solve_constraint(constraint, CONTROL_VAL_BOUND) + for arg in args: + if not arg.type.is_numeric(): + if arg.name in assignments: + val = assignments[arg.name] + elif isinstance(arg.type, T.Bool): + val = np.random.randint(0, CONTROL_VAL_BOUND) < CONTROL_VAL_BOUND / 2 + else: + val = np.random.randint(0, CONTROL_VAL_BOUND) + control_values[arg.name] = val + arg_values[str(arg.name)] = val + + for arg in args: + if arg.type.is_numeric(): + basetype = arg.type.basetype() + if isinstance(basetype, (T.F32, T.Num)): + dtype = np.float32 + elif isinstance(basetype, T.F16): + dtype = np.float16 + elif isinstance(basetype, T.F64): + dtype = np.float64 + elif isinstance(basetype, T.INT8): + dtype = np.int8 + elif isinstance(basetype, T.INT32): + dtype = np.int32 + elif isinstance(basetype, T.UINT8): + dtype = np.uint8 + elif isinstance(basetype, T.UINT16): + dtype = np.uint16 + + if arg.type.is_real_scalar(): + shape = (1,) + else: + shape = tuple( + eval_tensor_dimension(dim_expr, control_values) + for dim_expr in arg.type.shape() + ) + if dtype in [np.int8, np.int32]: + arg_values[str(arg.name)] = np.random.randint( + -INT_BOUND, INT_BOUND, shape, dtype=dtype + ) + elif dtype in [np.uint8, np.uint16]: + arg_values[str(arg.name)] = np.random.randint( + 0, INT_BOUND, shape, dtype=dtype + ) + elif dtype in [np.float16, np.float32, np.float64]: + arg_values[str(arg.name)] = ( + np.random.rand(*shape) * FLOAT_BOUND + ).astype(dtype) + + return arg_values + + +TEST_CASE_BOUND = 10 + + +def fuzz_reorder_stmts(s1, s2): + proc = s1.get_root() + proc_type_visitor = TypeVisitor({}) + proc_type_visitor.visit(proc) + cm = ConstraintMaker(proc_type_visitor.type_map) + constraint = Constraint(()) + for pred in proc.preds: + constraint = ConjunctionConstraint(constraint, cm.make_constraint(pred)) + constraint = ConjunctionConstraint(constraint, collect_path_constraints(s1, cm)) + args = [ + LoopIR.fnarg(name=var, type=arg_type, mem=DRAM, srcinfo=SrcInfo("", 0)) + for var, arg_type in get_free_variables( + proc_type_visitor.type_map, [s1._node, s2._node] + ).items() + ] + args = [arg for arg in args if not arg.type.is_numeric()] + [ + arg for arg in args if arg.type.is_numeric() + ] + for _ in range(TEST_CASE_BOUND): + arg_vals1 = generate_args(args, constraint, cm) + arg_vals2 = { + key: val.copy() if isinstance(val, np.ndarray) else val + for key, val in arg_vals1.items() + } + + run_interpreter( + LoopIR.proc( + name=proc.name, + args=args, + preds=[], + body=[s1._node, s2._node], + instr=None, + srcinfo=proc.srcinfo, + ), + arg_vals1, + ) + run_interpreter( + LoopIR.proc( + name=proc.name, + args=args, + preds=[], + body=[s2._node, s1._node], + instr=None, + srcinfo=proc.srcinfo, + ), + arg_vals2, + ) + for x in arg_vals1: + if not np.allclose(arg_vals1[x], arg_vals2[x]): + raise SchedulingError("mismatch found") diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py new file mode 100644 index 000000000..f91733444 --- /dev/null +++ b/src/exo/rewrite/constraint_solver.py @@ -0,0 +1,359 @@ +from dataclasses import dataclass +from typing import Union, Optional + +from exo.core.prelude import Sym +from ..core.LoopIR import LoopIR, T +import numpy as np + + +@dataclass +class Range: + lower_bound: Optional[int] + upper_bound: Optional[int] + + def intersect(self, other): + def wrap_none(option): + return [option] if option is not None else [] + + lbs = [*wrap_none(self.lower_bound), *wrap_none(other.lower_bound)] + ubs = [*wrap_none(self.upper_bound), 
*wrap_none(other.upper_bound)] + return Range( + None if len(lbs) == 0 else max(lbs), + None if len(ubs) == 0 else min(ubs), + ) + + +def simplify_disjunction(ranges: tuple[Range]) -> tuple[Range]: + bounds: list[tuple[Range, bool]] = [] + for r in ranges: + if ( + r.lower_bound is None + or r.upper_bound is None + or r.lower_bound <= r.upper_bound + ): + bounds.append((r, False)) + bounds.append((r, True)) + + def key(pair: tuple[Range, bool]) -> tuple[int, int, bool]: + r, is_upper = pair + if is_upper and r.upper_bound is None: + return (1, 0, is_upper) + if not is_upper and r.lower_bound is None: + return (-1, 0, is_upper) + return (0, r.upper_bound if is_upper else r.lower_bound, is_upper) + + bounds.sort(key=key) + + nest_depth = 0 + current_lower: Optional[int] = None + new_ranges: list[Range] = [] + for r, is_upper in bounds: + if nest_depth == 0: + assert not is_upper + current_lower = r.lower_bound + nest_depth += -1 if is_upper else 1 + if nest_depth == 0: + assert is_upper + new_ranges.append(Range(current_lower, r.upper_bound)) + return tuple(new_ranges) + + +@dataclass +class ConstraintTerm: + coefficient: int + syms: tuple[Sym] + + def negate(self) -> "ConstraintTerm": + return ConstraintTerm(-self.coefficient, self.syms) + + def multiply(self, other) -> "ConstraintTerm": + return ConstraintTerm( + self.coefficient * other.coefficient, self.syms + other.syms + ) + + def apply_assignments( + self, assignments: dict[Sym, int], target_sym: Sym + ) -> Optional[tuple[int, bool]]: + is_const = True + acc = self.coefficient + for sym in self.syms: + if sym == target_sym: + is_const = False + else: + if sym not in assignments: + return None + acc *= assignments[sym] + return (acc, is_const) + + +@dataclass +class Constraint: + terms: tuple[ConstraintTerm] + + def apply_assignments( + self, assignments: dict[Sym, int], target_sym: Sym + ) -> tuple[Range]: + offset, scale = 0, 0 + for term in self.terms: + assign_result = term.apply_assignments(assignments, target_sym) + if assign_result is None: + return (Range(None, None),) + else: + acc, is_const = assign_result + if is_const: + offset += acc + else: + scale += acc + if scale == 0: + if offset >= 0: + return (Range(None, None),) + else: + return (Range(0, -1),) + elif scale > 0: + return (Range(-offset / scale, None),) + else: + return (Range(None, -offset / scale),) + + def collect_syms(self) -> frozenset[Sym]: + return frozenset(sym for term in self.terms for sym in term.syms) + + def pretty_print(self) -> str: + return ( + " + ".join( + [ + f"{' * '.join([str(term.coefficient)] + [repr(sym) for sym in term.syms])}" + for term in self.terms + ] + ) + + " >= 0" + ) + + +GenericConstraint = Union[Constraint, "ConjunctionConstraint", "DisjunctionConstraint"] + + +@dataclass +class ConjunctionConstraint: + lhs: GenericConstraint + rhs: GenericConstraint + + def apply_assignments( + self, assignments: dict[Sym, int], target_sym: Sym + ) -> tuple[Range]: + lhs_ranges = self.lhs.apply_assignments(assignments, target_sym) + rhs_ranges = self.rhs.apply_assignments(assignments, target_sym) + return simplify_disjunction( + tuple( + lhs_range.intersect(rhs_range) + for lhs_range in lhs_ranges + for rhs_range in rhs_ranges + ) + ) + + def collect_syms(self) -> frozenset[Sym]: + return self.lhs.collect_syms() | self.rhs.collect_syms() + + def pretty_print(self) -> str: + return f"({self.lhs.pretty_print()}) and ({self.rhs.pretty_print()})" + + +@dataclass +class DisjunctionConstraint: + lhs: Constraint + rhs: Constraint + + def 
apply_assignments( + self, assignments: dict[Sym, int], target_sym: Sym + ) -> tuple[Range]: + lhs_ranges = self.lhs.apply_assignments(assignments, target_sym) + rhs_ranges = self.rhs.apply_assignments(assignments, target_sym) + return simplify_disjunction(lhs_ranges + rhs_ranges) + + def collect_syms(self) -> frozenset[Sym]: + return self.lhs.collect_syms() | self.rhs.collect_syms() + + def pretty_print(self) -> str: + return f"({self.lhs.pretty_print()}) or ({self.rhs.pretty_print()})" + + +class ConstraintMaker: + def __init__(self, type_map: dict[Sym, LoopIR.type]): + self.nonneg_vars = set( + sym + for sym, sym_type in type_map.items() + if isinstance(sym_type, (T.Size, T.Index)) + ) + self.bool_vars = set( + sym for sym, sym_type in type_map.items() if isinstance(sym_type, (T.Bool)) + ) + self.stride_dummies: dict[tuple[Sym, int], Sym] = {} + + def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: + # expect that expr is int type + if isinstance(expr, LoopIR.Read): + assert ( + len(expr.idx) == 0 + ), "indexing not supported in assertions (yet, todo)" + return (ConstraintTerm(1, (expr.name,)),) + elif isinstance(expr, LoopIR.Const): + return (ConstraintTerm(expr.val, ()),) + elif isinstance(expr, LoopIR.USub): + return tuple(term.negate() for term in self.make_constraint_terms(expr.arg)) + elif isinstance(expr, LoopIR.BinOp): + # TODO: support mod and div using extra variables + lhs_terms = self.make_constraint_terms(expr.lhs) + rhs_terms = self.make_constraint_terms(expr.rhs) + if expr.op == "+": + return lhs_terms + rhs_terms + elif expr.op == "-": + return lhs_terms + tuple(term.negate() for term in rhs_terms) + elif expr.op == "*": + return tuple( + lhs_term.multiply(rhs_term) + for lhs_term in lhs_terms + for rhs_term in rhs_terms + ) + else: + assert False, f"unsupported op in assertion: {expr.op}" + elif isinstance(expr, LoopIR.StrideExpr): + if (expr.name, expr.dim) not in self.stride_dummies: + new_sym = Sym("stride") + self.stride_dummies[(expr.name, expr.dim)] = new_sym + self.nonneg_vars.add(new_sym) + dummy = self.stride_dummies[(expr.name, expr.dim)] + return (ConstraintTerm(1, (dummy,)),) + else: + assert False, f"unsupported expr" + + def make_constraint( + self, + expr: LoopIR.expr, + ) -> GenericConstraint: + # expect that expr is bool type + if isinstance(expr, LoopIR.BinOp): + if expr.op == "and": + return ConjunctionConstraint( + self.make_constraint(expr.lhs), self.make_constraint(expr.rhs) + ) + elif expr.op == "or": + return DisjunctionConstraint( + self.make_constraint(expr.lhs), self.make_constraint(expr.rhs) + ) + elif expr.op == "<": + return Constraint( + self.make_constraint_terms(expr.rhs) + + tuple( + term.negate() for term in self.make_constraint_terms(expr.lhs) + ) + + (ConstraintTerm(-1, ()),) + ) + elif expr.op == ">": + return Constraint( + self.make_constraint_terms(expr.lhs) + + tuple( + term.negate() for term in self.make_constraint_terms(expr.rhs) + ) + + (ConstraintTerm(-1, ()),) + ) + elif expr.op == "<=": + return Constraint( + self.make_constraint_terms(expr.rhs) + + tuple( + term.negate() for term in self.make_constraint_terms(expr.lhs) + ) + ) + elif expr.op == ">=": + return Constraint( + self.make_constraint_terms(expr.lhs) + + tuple( + term.negate() for term in self.make_constraint_terms(expr.rhs) + ) + ) + elif expr.op == "==": + return ConjunctionConstraint( + Constraint( + self.make_constraint_terms(expr.rhs) + + tuple( + term.negate() + for term in self.make_constraint_terms(expr.lhs) + ) + ), + Constraint( 
+ self.make_constraint_terms(expr.lhs) + + tuple( + term.negate() + for term in self.make_constraint_terms(expr.rhs) + ) + ), + ) + else: + assert False, "boolean ops expected" + elif isinstance(expr, LoopIR.Read): + assert len(expr.idx) == 0, "cannot index into boolean" + return ConjunctionConstraint( + Constraint((ConstraintTerm(1, expr.name), ConstraintTerm(-1, ()))), + Constraint((ConstraintTerm(-1, expr.name), ConstraintTerm(1, ()))), + ) + else: + assert False, "only boolean expected" + + def solve_constraint( + self, constraint: GenericConstraint, bound: int, seed: Optional[int] = None + ): + if seed is not None: + np.random.seed(seed=seed) + assignments = {} + syms = constraint.collect_syms() + + bounding_range = Range(-bound, bound) + + def solve_recursive() -> bool: + sym_domains = [ + ( + tuple( + sym_range.intersect( + Range(0, bound) + if sym in self.nonneg_vars + else ( + Range(0, 1) + if sym in self.bool_vars + else Range(-bound, bound) + ) + ) + for sym_range in constraint.apply_assignments(assignments, sym) + ), + sym, + ) + for sym in syms - assignments.keys() + ] + if len(sym_domains) == 0: + return True + else: + + def domain_size(sym_domain: tuple[Range]) -> int: + return sum( + sym_range.upper_bound - sym_range.lower_bound + 1 + for sym_range in sym_domain + ) + + sym_domains.sort(key=lambda sym_domain: domain_size(sym_domain[0])) + sym_domain, sym = sym_domains[0] + if len(sym_domain) == 0: + return False + range_sizes = np.array( + [ + sym_range.upper_bound - sym_range.lower_bound + 1 + for sym_range in sym_domain + ] + ) + chosen_range = np.random.choice( + sym_domain, p=range_sizes / np.sum(range_sizes) + ) + assignments[sym] = np.random.randint( + chosen_range.lower_bound, chosen_range.upper_bound + 1 + ) + return solve_recursive() + + while not solve_recursive(): + assignments = {} + return assignments diff --git a/tests/golden/test_constraint_solver/test_make_constraint.txt b/tests/golden/test_constraint_solver/test_make_constraint.txt new file mode 100644 index 000000000..7f91efd63 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_make_constraint.txt @@ -0,0 +1 @@ +((4 * a_1 + 1 * b_2 + -1 * c_3 + -1 >= 0) or (3 + -1 * a_1 >= 0)) and (5 + -1 * b_2 + -1 >= 0) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt b/tests/golden/test_constraint_solver/test_solve.txt new file mode 100644 index 000000000..a263a2a17 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -0,0 +1 @@ +{b_5: 2, a_4: 2, c_6: 12} \ No newline at end of file diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py new file mode 100644 index 000000000..9afcd1474 --- /dev/null +++ b/tests/test_constraint_solver.py @@ -0,0 +1,36 @@ +from __future__ import annotations +from exo.core.prelude import Sym + +from exo.rewrite.constraint_solver import ConstraintMaker +from exo.core.LoopIR import LoopIR +from exo import proc +from exo.rewrite.chexo import TypeVisitor + + +def test_make_constraint(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) + pass + + foo_type = TypeVisitor({}) + foo_type.visit(foo._loopir_proc) + assert ( + golden + == ConstraintMaker(foo_type.type_map) + .make_constraint(foo._loopir_proc.preds[0]) + .pretty_print() + ) + + +def test_solve(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) + pass + + foo_type = TypeVisitor({}) + foo_type.visit(foo._loopir_proc) + cm = 
ConstraintMaker(foo_type.type_map) + constraint = cm.make_constraint(foo._loopir_proc.preds[0]) + assert golden == str(cm.solve_constraint(constraint, 16, 13)) From cb68959347fb7cc2331ed13b742c12a32cc83d64 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Fri, 7 Feb 2025 14:44:42 -0500 Subject: [PATCH 02/24] avoid sym repr in tests --- src/exo/rewrite/constraint_solver.py | 2 +- .../golden/test_constraint_solver/test_make_constraint.txt | 2 +- tests/golden/test_constraint_solver/test_solve.txt | 2 +- tests/test_constraint_solver.py | 7 ++++++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index f91733444..ecad9a005 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -121,7 +121,7 @@ def pretty_print(self) -> str: return ( " + ".join( [ - f"{' * '.join([str(term.coefficient)] + [repr(sym) for sym in term.syms])}" + f"{' * '.join([str(term.coefficient)] + [str(sym) for sym in term.syms])}" for term in self.terms ] ) diff --git a/tests/golden/test_constraint_solver/test_make_constraint.txt b/tests/golden/test_constraint_solver/test_make_constraint.txt index 7f91efd63..7291f502e 100644 --- a/tests/golden/test_constraint_solver/test_make_constraint.txt +++ b/tests/golden/test_constraint_solver/test_make_constraint.txt @@ -1 +1 @@ -((4 * a_1 + 1 * b_2 + -1 * c_3 + -1 >= 0) or (3 + -1 * a_1 >= 0)) and (5 + -1 * b_2 + -1 >= 0) \ No newline at end of file +((4 * a + 1 * b + -1 * c + -1 >= 0) or (3 + -1 * a >= 0)) and (5 + -1 * b + -1 >= 0) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt b/tests/golden/test_constraint_solver/test_solve.txt index a263a2a17..8d0755a0f 100644 --- a/tests/golden/test_constraint_solver/test_solve.txt +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -1 +1 @@ -{b_5: 2, a_4: 2, c_6: 12} \ No newline at end of file +b = 2, a = 2, c = 12 \ No newline at end of file diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index 9afcd1474..d80ec9629 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -33,4 +33,9 @@ def foo(a: size, b: size, c: size): foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) constraint = cm.make_constraint(foo._loopir_proc.preds[0]) - assert golden == str(cm.solve_constraint(constraint, 16, 13)) + assert golden == ", ".join( + [ + f"{str(sym)} = {val}" + for sym, val in cm.solve_constraint(constraint, 16, 13).items() + ] + ) From e901372bab3f0a622830a53ff4b2c4c0722f32e9 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Fri, 7 Feb 2025 16:18:46 -0500 Subject: [PATCH 03/24] add div and mod to constraint solver --- src/exo/rewrite/chexo.py | 34 ++++--- src/exo/rewrite/constraint_solver.py | 96 +++++++++++++------ .../test_constraint_solver/test_divmod.txt | 1 + .../test_divmod_solve.txt | 1 + tests/test_constraint_solver.py | 35 ++++++- 5 files changed, 126 insertions(+), 41 deletions(-) create mode 100644 tests/golden/test_constraint_solver/test_divmod.txt create mode 100644 tests/golden/test_constraint_solver/test_divmod_solve.txt diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index 171d6eda6..c23754099 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -1,8 +1,9 @@ from typing import Optional + from ..core.LoopIR import LoopIR, T -from dataclasses import dataclass +from dataclasses import dataclass, field from ..core.prelude import Sym, SrcInfo -from 
..core.memory import DRAM +from ..core.memory import DRAM, Memory from ..backend.LoopIR_interpreter import run_interpreter import numpy as np from .new_eff import SchedulingError @@ -42,7 +43,8 @@ def visit_generic(self, node): @dataclass class TypeVisitor(LoopIRVisitor): - type_map: dict[Sym, LoopIR.type] + type_map: dict[Sym, LoopIR.type] = field(default_factory=lambda: {}) + mem_map: dict[Sym, Memory] = field(default_factory=lambda: {}) def visit(self, node): if isinstance(node, LoopIR.For): @@ -50,17 +52,20 @@ def visit(self, node): self.visit_generic(node) elif isinstance(node, LoopIR.Alloc): self.type_map[node.name] = node.type + self.mem_map[node.name] = node.mem elif isinstance(node, LoopIR.WindowStmt): self.type_map[node.name] = node.rhs.type elif isinstance(node, LoopIR.fnarg): self.type_map[node.name] = node.type + if node.mem: + self.mem_map[node.name] = node.mem else: self.visit_generic(node) @dataclass class UsedVariableVisitor(LoopIRVisitor): - used_vars: set[Sym] + used_vars: set[Sym] = field(default_factory=lambda: set()) def visit(self, node): if isinstance(node, Sym): @@ -69,16 +74,16 @@ def visit(self, node): self.visit_generic(node) -def get_free_variables(type_map, fragment): - fragment_type_visitor = TypeVisitor({}) - fragment_var_visitor = UsedVariableVisitor(set()) +def get_free_variables(type_map, mem_map, fragment): + fragment_type_visitor = TypeVisitor() + fragment_var_visitor = UsedVariableVisitor() for stmt in fragment: fragment_type_visitor.visit(stmt) fragment_var_visitor.visit(stmt) for var in fragment_var_visitor.used_vars - fragment_type_visitor.type_map.keys(): fragment_var_visitor.visit(type_map[var]) return { - var: type_map[var] + var: (type_map[var], mem_map[var] if var in mem_map else None) for var in fragment_var_visitor.used_vars - fragment_type_visitor.type_map.keys() } @@ -223,7 +228,7 @@ def generate_args(args, constraint: Constraint, cm: ConstraintMaker): def fuzz_reorder_stmts(s1, s2): proc = s1.get_root() - proc_type_visitor = TypeVisitor({}) + proc_type_visitor = TypeVisitor() proc_type_visitor.visit(proc) cm = ConstraintMaker(proc_type_visitor.type_map) constraint = Constraint(()) @@ -231,9 +236,14 @@ def fuzz_reorder_stmts(s1, s2): constraint = ConjunctionConstraint(constraint, cm.make_constraint(pred)) constraint = ConjunctionConstraint(constraint, collect_path_constraints(s1, cm)) args = [ - LoopIR.fnarg(name=var, type=arg_type, mem=DRAM, srcinfo=SrcInfo("", 0)) - for var, arg_type in get_free_variables( - proc_type_visitor.type_map, [s1._node, s2._node] + LoopIR.fnarg( + name=var, + type=arg_type, + mem=DRAM if arg_mem is None else arg_mem, + srcinfo=SrcInfo("", 0), + ) + for var, (arg_type, arg_mem) in get_free_variables( + proc_type_visitor.type_map, proc_type_visitor.mem_map, [s1._node, s2._node] ).items() ] args = [arg for arg in args if not arg.type.is_numeric()] + [ diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index ecad9a005..c1e198ae6 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -110,9 +110,9 @@ def apply_assignments( else: return (Range(0, -1),) elif scale > 0: - return (Range(-offset / scale, None),) + return (Range(int(np.ceil(-offset / scale)), None),) else: - return (Range(None, -offset / scale),) + return (Range(None, int(np.floor(-offset / scale))),) def collect_syms(self) -> frozenset[Sym]: return frozenset(sym for term in self.terms for sym in term.syms) @@ -186,6 +186,7 @@ def __init__(self, type_map: dict[Sym, 
LoopIR.type]): self.bool_vars = set( sym for sym, sym_type in type_map.items() if isinstance(sym_type, (T.Bool)) ) + self.div_constraint = Constraint(()) self.stride_dummies: dict[tuple[Sym, int], Sym] = {} def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: @@ -213,6 +214,49 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: for lhs_term in lhs_terms for rhs_term in rhs_terms ) + elif expr.op in ["/", "%"]: + div, rem = Sym("div"), Sym("rem") + div_terms = ( + tuple( + ConstraintTerm(term.coefficient, term.syms + (div,)) + for term in rhs_terms + ) + + tuple(term.negate() for term in lhs_terms) + + (ConstraintTerm(1, (rem,)),) + ) + self.div_constraint = ConjunctionConstraint( + self.div_constraint, + ConjunctionConstraint( + ConjunctionConstraint( + Constraint(div_terms), + Constraint(tuple(term.negate() for term in div_terms)), + ), + DisjunctionConstraint( + ConjunctionConstraint( + Constraint((ConstraintTerm(1, (rem,)),)), + Constraint( + rhs_terms + + ( + ConstraintTerm(-1, (rem,)), + ConstraintTerm(-1, ()), + ) + ), + ), + ConjunctionConstraint( + Constraint((ConstraintTerm(-1, (rem,)),)), + Constraint( + tuple(term.negate() for term in rhs_terms) + + ( + ConstraintTerm(1, (rem,)), + ConstraintTerm(1, ()), + ) + ), + ), + ), + ), + ) + + return (ConstraintTerm(1, (div if expr.op == "/" else rem,)),) else: assert False, f"unsupported op in assertion: {expr.op}" elif isinstance(expr, LoopIR.StrideExpr): @@ -270,21 +314,11 @@ def make_constraint( ) ) elif expr.op == "==": + lhs_terms = self.make_constraint_terms(expr.lhs) + rhs_terms = self.make_constraint_terms(expr.rhs) return ConjunctionConstraint( - Constraint( - self.make_constraint_terms(expr.rhs) - + tuple( - term.negate() - for term in self.make_constraint_terms(expr.lhs) - ) - ), - Constraint( - self.make_constraint_terms(expr.lhs) - + tuple( - term.negate() - for term in self.make_constraint_terms(expr.rhs) - ) - ), + Constraint(rhs_terms + tuple(term.negate() for term in lhs_terms)), + Constraint(lhs_terms + tuple(term.negate() for term in rhs_terms)), ) else: assert False, "boolean ops expected" @@ -294,6 +328,11 @@ def make_constraint( Constraint((ConstraintTerm(1, expr.name), ConstraintTerm(-1, ()))), Constraint((ConstraintTerm(-1, expr.name), ConstraintTerm(1, ()))), ) + elif isinstance(expr, LoopIR.Const): + if expr.val: + return Constraint(()) + else: + return Constraint((ConstraintTerm(-1, ()))) else: assert False, "only boolean expected" @@ -302,25 +341,28 @@ def solve_constraint( ): if seed is not None: np.random.seed(seed=seed) + constraint = ConjunctionConstraint(constraint, self.div_constraint) assignments = {} syms = constraint.collect_syms() - bounding_range = Range(-bound, bound) - def solve_recursive() -> bool: sym_domains = [ ( - tuple( - sym_range.intersect( - Range(0, bound) - if sym in self.nonneg_vars - else ( - Range(0, 1) - if sym in self.bool_vars - else Range(-bound, bound) + simplify_disjunction( + tuple( + sym_range.intersect( + Range(0, bound) + if sym in self.nonneg_vars + else ( + Range(0, 1) + if sym in self.bool_vars + else Range(-bound, bound) + ) + ) + for sym_range in constraint.apply_assignments( + assignments, sym ) ) - for sym_range in constraint.apply_assignments(assignments, sym) ), sym, ) diff --git a/tests/golden/test_constraint_solver/test_divmod.txt b/tests/golden/test_constraint_solver/test_divmod.txt new file mode 100644 index 000000000..4887b6f37 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_divmod.txt @@ -0,0 
+1 @@ +(((4 * a + 1 * b + -1 * c + -1 >= 0) or (3 + -1 * a >= 0)) and (5 + -1 * b + -1 >= 0)) and ((3 + -1 * rem >= 0) and (1 * rem + -3 >= 0)) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod_solve.txt b/tests/golden/test_constraint_solver/test_divmod_solve.txt new file mode 100644 index 000000000..b8d548a89 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_divmod_solve.txt @@ -0,0 +1 @@ +rem = 3, b = 4, a = 11, div = 2, c = 11 \ No newline at end of file diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index d80ec9629..cfe889390 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -13,7 +13,7 @@ def foo(a: size, b: size, c: size): assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) pass - foo_type = TypeVisitor({}) + foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) assert ( golden @@ -29,7 +29,38 @@ def foo(a: size, b: size, c: size): assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) pass - foo_type = TypeVisitor({}) + foo_type = TypeVisitor() + foo_type.visit(foo._loopir_proc) + cm = ConstraintMaker(foo_type.type_map) + constraint = cm.make_constraint(foo._loopir_proc.preds[0]) + assert golden == ", ".join( + [ + f"{str(sym)} = {val}" + for sym, val in cm.solve_constraint(constraint, 16, 13).items() + ] + ) + + +def test_divmod(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) and (a % 4 == 3) + pass + + foo_type = TypeVisitor() + foo_type.visit(foo._loopir_proc) + cm = ConstraintMaker(foo_type.type_map) + constraint = cm.make_constraint(foo._loopir_proc.preds[0]) + assert golden == constraint.pretty_print() + + +def test_divmod_solve(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) and (a % 4 == 3) + pass + + foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) constraint = cm.make_constraint(foo._loopir_proc.preds[0]) From d8febabb2328af32a42179915eb1a4d5c9428565 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 18 Feb 2025 15:05:47 -0500 Subject: [PATCH 04/24] Collect ilp constraints --- src/exo/rewrite/chexo.py | 24 +- src/exo/rewrite/constraint_solver.py | 460 +++++++++++---------------- tests/test_constraint_solver.py | 16 +- 3 files changed, 207 insertions(+), 293 deletions(-) diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index c23754099..37719a883 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -95,11 +95,11 @@ def eval_tensor_dimension(dim_expr, control_values): elif isinstance(dim_expr, LoopIR.Const): return dim_expr.val elif isinstance(dim_expr, LoopIR.USub): - return -eval_tensor_dimension(dim_expr.arg) + return -eval_tensor_dimension(dim_expr.arg, control_values) elif isinstance(dim_expr, LoopIR.BinOp): - lhs, rhs = eval_tensor_dimension(dim_expr.lhs), eval_tensor_dimension( - dim_expr.rhs - ) + lhs, rhs = eval_tensor_dimension( + dim_expr.lhs, control_values + ), eval_tensor_dimension(dim_expr.rhs, control_values) if dim_expr.op == "+": return lhs + rhs elif dim_expr.op == "-": @@ -134,7 +134,8 @@ def eval_tensor_dimension(dim_expr, control_values): assert False, "unexpected expression type in tensor dimension" -CONTROL_VAL_BOUND = 16 +CONTROL_VAL_BOUND = 128 +SEARCH_LIMIT = 10 INT_BOUND = 128 FLOAT_BOUND = 32 @@ -162,7 +163,7 @@ def collect_path_constraints(cursor, cm: ConstraintMaker) -> GenericConstraint: ), ) elif isinstance(cur._node, LoopIR.If): - 
result = ConjunctionConstraint(result, cm.make_constraint(cur._node.cond)) + result = ConjunctionConstraint(result, cm.make_constraints(cur._node.cond)) cur = cur.parent() return result @@ -170,7 +171,11 @@ def collect_path_constraints(cursor, cm: ConstraintMaker) -> GenericConstraint: def generate_args(args, constraint: Constraint, cm: ConstraintMaker): arg_values = {} control_values = {} - assignments = cm.solve_constraint(constraint, CONTROL_VAL_BOUND) + assignments = cm.solve_constraint( + constraint, bound=CONTROL_VAL_BOUND, search_limit=SEARCH_LIMIT + ) + if assignments is None: + return None for arg in args: if not arg.type.is_numeric(): if arg.name in assignments: @@ -227,13 +232,14 @@ def generate_args(args, constraint: Constraint, cm: ConstraintMaker): def fuzz_reorder_stmts(s1, s2): + print(s1) proc = s1.get_root() proc_type_visitor = TypeVisitor() proc_type_visitor.visit(proc) cm = ConstraintMaker(proc_type_visitor.type_map) constraint = Constraint(()) for pred in proc.preds: - constraint = ConjunctionConstraint(constraint, cm.make_constraint(pred)) + constraint = ConjunctionConstraint(constraint, cm.make_constraints(pred)) constraint = ConjunctionConstraint(constraint, collect_path_constraints(s1, cm)) args = [ LoopIR.fnarg( @@ -251,6 +257,8 @@ def fuzz_reorder_stmts(s1, s2): ] for _ in range(TEST_CASE_BOUND): arg_vals1 = generate_args(args, constraint, cm) + if arg_vals1 is None: + continue arg_vals2 = { key: val.copy() if isinstance(val, np.ndarray) else val for key, val in arg_vals1.items() diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index c1e198ae6..4509a57ca 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Union, Optional from exo.core.prelude import Sym @@ -6,58 +6,6 @@ import numpy as np -@dataclass -class Range: - lower_bound: Optional[int] - upper_bound: Optional[int] - - def intersect(self, other): - def wrap_none(option): - return [option] if option is not None else [] - - lbs = [*wrap_none(self.lower_bound), *wrap_none(other.lower_bound)] - ubs = [*wrap_none(self.upper_bound), *wrap_none(other.upper_bound)] - return Range( - None if len(lbs) == 0 else max(lbs), - None if len(ubs) == 0 else min(ubs), - ) - - -def simplify_disjunction(ranges: tuple[Range]) -> tuple[Range]: - bounds: list[tuple[Range, bool]] = [] - for r in ranges: - if ( - r.lower_bound is None - or r.upper_bound is None - or r.lower_bound <= r.upper_bound - ): - bounds.append((r, False)) - bounds.append((r, True)) - - def key(pair: tuple[Range, bool]) -> tuple[int, int, bool]: - r, is_upper = pair - if is_upper and r.upper_bound is None: - return (1, 0, is_upper) - if not is_upper and r.lower_bound is None: - return (-1, 0, is_upper) - return (0, r.upper_bound if is_upper else r.lower_bound, is_upper) - - bounds.sort(key=key) - - nest_depth = 0 - current_lower: Optional[int] = None - new_ranges: list[Range] = [] - for r, is_upper in bounds: - if nest_depth == 0: - assert not is_upper - current_lower = r.lower_bound - nest_depth += -1 if is_upper else 1 - if nest_depth == 0: - assert is_upper - new_ranges.append(Range(current_lower, r.upper_bound)) - return tuple(new_ranges) - - @dataclass class ConstraintTerm: coefficient: int @@ -72,18 +20,35 @@ def multiply(self, other) -> "ConstraintTerm": ) def apply_assignments( - self, assignments: dict[Sym, int], target_sym: Sym - ) -> Optional[tuple[int, 
bool]]: - is_const = True + self, assignments: dict[Sym, int] + ) -> Optional[tuple[int, Optional[Sym]]]: + target_sym = None acc = self.coefficient for sym in self.syms: - if sym == target_sym: - is_const = False + if sym in assignments: + acc *= assignments[sym] else: - if sym not in assignments: + if target_sym is None: + target_sym = sym + else: return None - acc *= assignments[sym] - return (acc, is_const) + return (acc, target_sym) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + occurrences = set() + result = set() + for sym in self.syms: + if sym in occurrences: + result.add(sym) + else: + occurrences.add(sym) + return frozenset(result) + + +@dataclass +class LinearConstraint: + coefficients: dict[Sym, int] + offset: int @dataclass @@ -91,32 +56,32 @@ class Constraint: terms: tuple[ConstraintTerm] def apply_assignments( - self, assignments: dict[Sym, int], target_sym: Sym - ) -> tuple[Range]: - offset, scale = 0, 0 + self, assignments: dict[Sym, int] + ) -> Optional[LinearConstraint]: + coefficients = {} + offset = 0 for term in self.terms: - assign_result = term.apply_assignments(assignments, target_sym) + assign_result = term.apply_assignments(assignments) if assign_result is None: - return (Range(None, None),) + return None else: - acc, is_const = assign_result - if is_const: - offset += acc + coefficient, sym = assign_result + if sym is None: + offset += coefficient else: - scale += acc - if scale == 0: - if offset >= 0: - return (Range(None, None),) - else: - return (Range(0, -1),) - elif scale > 0: - return (Range(int(np.ceil(-offset / scale)), None),) - else: - return (Range(None, int(np.floor(-offset / scale))),) + if sym not in coefficients: + coefficients[sym] = 0 + coefficients[sym] += coefficient + return LinearConstraint(coefficients, offset) def collect_syms(self) -> frozenset[Sym]: return frozenset(sym for term in self.terms for sym in term.syms) + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return frozenset().union( + *[term.collect_nonlinear_syms() for term in self.terms] + ) + def pretty_print(self) -> str: return ( " + ".join( @@ -125,69 +90,45 @@ def pretty_print(self) -> str: for term in self.terms ] ) - + " >= 0" - ) - - -GenericConstraint = Union[Constraint, "ConjunctionConstraint", "DisjunctionConstraint"] - - -@dataclass -class ConjunctionConstraint: - lhs: GenericConstraint - rhs: GenericConstraint - - def apply_assignments( - self, assignments: dict[Sym, int], target_sym: Sym - ) -> tuple[Range]: - lhs_ranges = self.lhs.apply_assignments(assignments, target_sym) - rhs_ranges = self.rhs.apply_assignments(assignments, target_sym) - return simplify_disjunction( - tuple( - lhs_range.intersect(rhs_range) - for lhs_range in lhs_ranges - for rhs_range in rhs_ranges - ) + + " == 0" ) - def collect_syms(self) -> frozenset[Sym]: - return self.lhs.collect_syms() | self.rhs.collect_syms() - - def pretty_print(self) -> str: - return f"({self.lhs.pretty_print()}) and ({self.rhs.pretty_print()})" - - -@dataclass -class DisjunctionConstraint: - lhs: Constraint - rhs: Constraint - - def apply_assignments( - self, assignments: dict[Sym, int], target_sym: Sym - ) -> tuple[Range]: - lhs_ranges = self.lhs.apply_assignments(assignments, target_sym) - rhs_ranges = self.rhs.apply_assignments(assignments, target_sym) - return simplify_disjunction(lhs_ranges + rhs_ranges) - - def collect_syms(self) -> frozenset[Sym]: - return self.lhs.collect_syms() | self.rhs.collect_syms() - - def pretty_print(self) -> str: - return f"({self.lhs.pretty_print()}) or 
({self.rhs.pretty_print()})" - class ConstraintMaker: def __init__(self, type_map: dict[Sym, LoopIR.type]): - self.nonneg_vars = set( - sym - for sym, sym_type in type_map.items() - if isinstance(sym_type, (T.Size, T.Index)) - ) - self.bool_vars = set( - sym for sym, sym_type in type_map.items() if isinstance(sym_type, (T.Bool)) - ) - self.div_constraint = Constraint(()) + self.unconstrained_var_subs: dict[Sym, tuple[ConstraintTerm]] = {} + self.extra_constraints: list[Constraint] = [] self.stride_dummies: dict[tuple[Sym, int], Sym] = {} + for sym, sym_type in type_map.items(): + if isinstance(sym_type, T.Size, T.Stride): + # positive constraint + self.extra_constraints.append( + Constraint( + ( + ConstraintTerm(1, (sym,)), + ConstraintTerm(-1, ()), + ConstraintTerm(-1, (Sym("slack"),)), + ) + ) + ) + elif isinstance(sym_type, T.Int, T.Num): + # unsigned variables are represented as a - b, where a and b are nonnegative + a, b = Sym("a"), Sym("b") + self.unconstrained_var_subs[sym] = ( + ConstraintTerm(1, (a,)), + ConstraintTerm(-1, (b,)), + ) + elif isinstance(sym_type, T.Bool): + # constrained to [0, 1] + self.extra_constraints.append( + Constraint( + ( + ConstraintTerm(1, (sym,)), + ConstraintTerm(-1, ()), + ConstraintTerm(1, (Sym("slack"))), + ) + ) + ) def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: # expect that expr is int type @@ -195,7 +136,10 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: assert ( len(expr.idx) == 0 ), "indexing not supported in assertions (yet, todo)" - return (ConstraintTerm(1, (expr.name,)),) + if expr.name in self.unconstrained_var_subs: + return self.unconstrained_var_subs[expr.name] + else: + return (ConstraintTerm(1, (expr.name,)),) elif isinstance(expr, LoopIR.Const): return (ConstraintTerm(expr.val, ()),) elif isinstance(expr, LoopIR.USub): @@ -216,47 +160,26 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: ) elif expr.op in ["/", "%"]: div, rem = Sym("div"), Sym("rem") - div_terms = ( - tuple( - ConstraintTerm(term.coefficient, term.syms + (div,)) - for term in rhs_terms + self.extra_constraints.append( + Constraint( + lhs_terms + + (ConstraintTerm(1, (rem,))) + + tuple( + rhs_term.multiply(ConstraintTerm(1, (div,))) + for rhs_term in rhs_terms + ) ) - + tuple(term.negate() for term in lhs_terms) - + (ConstraintTerm(1, (rem,)),) ) - self.div_constraint = ConjunctionConstraint( - self.div_constraint, - ConjunctionConstraint( - ConjunctionConstraint( - Constraint(div_terms), - Constraint(tuple(term.negate() for term in div_terms)), - ), - DisjunctionConstraint( - ConjunctionConstraint( - Constraint((ConstraintTerm(1, (rem,)),)), - Constraint( - rhs_terms - + ( - ConstraintTerm(-1, (rem,)), - ConstraintTerm(-1, ()), - ) - ), - ), - ConjunctionConstraint( - Constraint((ConstraintTerm(-1, (rem,)),)), - Constraint( - tuple(term.negate() for term in rhs_terms) - + ( - ConstraintTerm(1, (rem,)), - ConstraintTerm(1, ()), - ) - ), - ), - ), - ), + self.extra_constraints.append( + Constraint( + ( + ConstraintTerm(-1, (rem,)), + ConstraintTerm(-1, (Sym("slack"),)), + ) + + rhs_terms + ) ) - - return (ConstraintTerm(1, (div if expr.op == "/" else rem,)),) + return (ConstraintTerm(1, (rem if expr.op == "%" else div,)),) else: assert False, f"unsupported op in assertion: {expr.op}" elif isinstance(expr, LoopIR.StrideExpr): @@ -269,133 +192,112 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: else: assert False, f"unsupported expr" - def 
make_constraint( + def make_constraints( self, expr: LoopIR.expr, - ) -> GenericConstraint: + ) -> tuple[Constraint]: # expect that expr is bool type if isinstance(expr, LoopIR.BinOp): if expr.op == "and": - return ConjunctionConstraint( - self.make_constraint(expr.lhs), self.make_constraint(expr.rhs) - ) + return self.make_constraints(expr.lhs) + self.make_constraints(expr.rhs) elif expr.op == "or": - return DisjunctionConstraint( - self.make_constraint(expr.lhs), self.make_constraint(expr.rhs) - ) - elif expr.op == "<": - return Constraint( - self.make_constraint_terms(expr.rhs) - + tuple( - term.negate() for term in self.make_constraint_terms(expr.lhs) - ) - + (ConstraintTerm(-1, ()),) - ) - elif expr.op == ">": - return Constraint( - self.make_constraint_terms(expr.lhs) - + tuple( - term.negate() for term in self.make_constraint_terms(expr.rhs) - ) - + (ConstraintTerm(-1, ()),) - ) - elif expr.op == "<=": - return Constraint( - self.make_constraint_terms(expr.rhs) - + tuple( - term.negate() for term in self.make_constraint_terms(expr.lhs) - ) - ) - elif expr.op == ">=": - return Constraint( - self.make_constraint_terms(expr.lhs) - + tuple( - term.negate() for term in self.make_constraint_terms(expr.rhs) + # disjunction multiplies all constraints + lhs_constraints, rhs_constraints = self.make_constraints( + expr.lhs + ), self.make_constraints(expr.rhs) + return tuple( + Constraint( + tuple( + lhs_term.multiply(rhs_term) + for lhs_term in lhs_constraint.terms + for rhs_term in rhs_constraint.terms + ) ) - ) - elif expr.op == "==": - lhs_terms = self.make_constraint_terms(expr.lhs) - rhs_terms = self.make_constraint_terms(expr.rhs) - return ConjunctionConstraint( - Constraint(rhs_terms + tuple(term.negate() for term in lhs_terms)), - Constraint(lhs_terms + tuple(term.negate() for term in rhs_terms)), + for lhs_constraint in lhs_constraints + for rhs_constraint in rhs_constraints ) else: - assert False, "boolean ops expected" + return ( + self.make_constraint_from_inequality(expr.lhs, expr.rhs, expr.op), + ) elif isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0, "cannot index into boolean" - return ConjunctionConstraint( - Constraint((ConstraintTerm(1, expr.name), ConstraintTerm(-1, ()))), - Constraint((ConstraintTerm(-1, expr.name), ConstraintTerm(1, ()))), + return ( + Constraint((ConstraintTerm(1, (expr.name,)), ConstraintTerm(-1, ()))), ) elif isinstance(expr, LoopIR.Const): if expr.val: return Constraint(()) else: - return Constraint((ConstraintTerm(-1, ()))) + return Constraint((ConstraintTerm(1, ()))) else: assert False, "only boolean expected" - def solve_constraint( - self, constraint: GenericConstraint, bound: int, seed: Optional[int] = None + def make_constraint_from_inequality( + self, lhs: LoopIR.expr, rhs: LoopIR.expr, op: str + ) -> Constraint: + lhs_terms = self.make_constraint_terms(lhs) + rhs_terms = self.make_constraint_terms(rhs) + main_terms = rhs_terms + tuple(term.negate() for term in lhs_terms) + if op == "<": + slack_terms = ( + ConstraintTerm(-1, (Sym("slack"),)), + ConstraintTerm(-1, ()), + ) + elif op == ">": + slack_terms = ( + ConstraintTerm(1, (Sym("slack"),)), + ConstraintTerm(1, ()), + ) + elif op == "<=": + slack_terms = (ConstraintTerm(-1, (Sym("slack"),)),) + elif op == ">=": + slack_terms = (ConstraintTerm(1, (Sym("slack"),)),) + elif op == "==": + slack_terms = () + else: + assert False, "boolean ops expected" + return Constraint(main_terms + slack_terms) + + def solve_constraints( + self, + constraints: tuple[Constraint], + *, + search_limit: 
int, + seed: Optional[int] = None, ): if seed is not None: np.random.seed(seed=seed) - constraint = ConjunctionConstraint(constraint, self.div_constraint) + all_constraints = constraints + tuple(self.extra_constraints) assignments = {} - syms = constraint.collect_syms() - def solve_recursive() -> bool: - sym_domains = [ - ( - simplify_disjunction( - tuple( - sym_range.intersect( - Range(0, bound) - if sym in self.nonneg_vars - else ( - Range(0, 1) - if sym in self.bool_vars - else Range(-bound, bound) - ) - ) - for sym_range in constraint.apply_assignments( - assignments, sym - ) - ) - ), - sym, - ) - for sym in syms - assignments.keys() - ] - if len(sym_domains) == 0: - return True + def solve_recursive(): + linear_constraints: list[LinearConstraint] = [] + linear_constraint_syms: set[Sym] = set() + nonlinear_syms: set[Sym] = set() + for constraint in all_constraints: + assign_result = constraint.apply_assignments(assignments) + if assign_result is not None: + linear_constraints.append(assign_result) + linear_constraint_syms |= { + sym for sym in assign_result.coefficients.keys() + } + + nonlinear_syms |= constraint.collect_nonlinear_syms() + sym_ordering = {sym: i for i, sym in enumerate(linear_constraint_syms)} + matrix_Ab = np.zeros( + (len(linear_constraints), len(linear_constraint_syms) + 1), + dtype=np.int32, + ) + for row, linear_constraint in enumerate(linear_constraints): + for sym, coefficient in linear_constraint.coefficients: + matrix_Ab[row, sym_ordering[sym]] = coefficient + matrix_Ab[row, len(linear_constraint_syms)] = linear_constraint.offset + + for _ in range(search_limit): + if solve_recursive(): + return assignments else: + assignments = {} - def domain_size(sym_domain: tuple[Range]) -> int: - return sum( - sym_range.upper_bound - sym_range.lower_bound + 1 - for sym_range in sym_domain - ) - - sym_domains.sort(key=lambda sym_domain: domain_size(sym_domain[0])) - sym_domain, sym = sym_domains[0] - if len(sym_domain) == 0: - return False - range_sizes = np.array( - [ - sym_range.upper_bound - sym_range.lower_bound + 1 - for sym_range in sym_domain - ] - ) - chosen_range = np.random.choice( - sym_domain, p=range_sizes / np.sum(range_sizes) - ) - assignments[sym] = np.random.randint( - chosen_range.lower_bound, chosen_range.upper_bound + 1 - ) - return solve_recursive() - - while not solve_recursive(): - assignments = {} - return assignments + return None diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index cfe889390..1ac0b61e6 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -18,7 +18,7 @@ def foo(a: size, b: size, c: size): assert ( golden == ConstraintMaker(foo_type.type_map) - .make_constraint(foo._loopir_proc.preds[0]) + .make_constraints(foo._loopir_proc.preds[0]) .pretty_print() ) @@ -32,11 +32,13 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) - constraint = cm.make_constraint(foo._loopir_proc.preds[0]) + constraint = cm.make_constraints(foo._loopir_proc.preds[0]) assert golden == ", ".join( [ f"{str(sym)} = {val}" - for sym, val in cm.solve_constraint(constraint, 16, 13).items() + for sym, val in cm.solve_constraint( + constraint, bound=16, search_limit=10, seed=13 + ).items() ] ) @@ -50,7 +52,7 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) - constraint = cm.make_constraint(foo._loopir_proc.preds[0]) + constraint = 
cm.make_constraints(foo._loopir_proc.preds[0]) assert golden == constraint.pretty_print() @@ -63,10 +65,12 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) - constraint = cm.make_constraint(foo._loopir_proc.preds[0]) + constraint = cm.make_constraints(foo._loopir_proc.preds[0]) assert golden == ", ".join( [ f"{str(sym)} = {val}" - for sym, val in cm.solve_constraint(constraint, 16, 13).items() + for sym, val in cm.solve_constraint( + constraint, bound=16, search_limit=10, seed=13 + ).items() ] ) From 45903a2bbe80727dd92679f051c63a2eae1e90e4 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 19 Feb 2025 18:19:12 -0500 Subject: [PATCH 05/24] finish integer constraint solver --- requirements.txt | 2 + src/exo/rewrite/chexo.py | 47 ++--- src/exo/rewrite/constraint_solver.py | 190 ++++++++++++++---- .../test_constraint_solver/test_divmod.txt | 4 +- .../test_divmod_solve.txt | 2 +- .../test_make_constraint.txt | 3 +- .../test_constraint_solver/test_solve.txt | 2 +- tests/test_constraint_solver.py | 30 +-- 8 files changed, 193 insertions(+), 87 deletions(-) diff --git a/requirements.txt b/requirements.txt index 03ed1c373..dbdf8f7a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ asdl==0.1.5 build==1.2.2.post1 z3-solver==4.14.0.0 yapf==0.43.0 +scipy==1.6.2 +hsnf==0.3.16 \ No newline at end of file diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index 37719a883..705f8926f 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -8,9 +8,6 @@ import numpy as np from .new_eff import SchedulingError from .constraint_solver import ( - ConjunctionConstraint, - ConstraintTerm, - GenericConstraint, Constraint, ConstraintMaker, ) @@ -140,40 +137,27 @@ def eval_tensor_dimension(dim_expr, control_values): FLOAT_BOUND = 32 -def collect_path_constraints(cursor, cm: ConstraintMaker) -> GenericConstraint: +def collect_path_constraints(cursor, cm: ConstraintMaker) -> tuple[Constraint]: cur = cursor - result = Constraint(()) + result = [] while cur.depth() != 0: if isinstance(cur._node, LoopIR.For): - result = ConjunctionConstraint( - ConjunctionConstraint( - result, - Constraint( - (ConstraintTerm(1, (cur._node.iter,)),) - + tuple( - term.negate() - for term in cm.make_constraint_terms(cur._node.lo) - ) - ), - ), - Constraint( - (ConstraintTerm(-1, (cur._node.iter,)),) - + cm.make_constraint_terms(cur._node.hi) - + (ConstraintTerm(-1, ()),) - ), + result.append( + cm.make_constraint_from_inequality(cur._node.iter, cur._node.lo, ">=") + ) + result.append( + cm.make_constraint_from_inequality(cur._node.iter, cur._node.hi, "<") ) elif isinstance(cur._node, LoopIR.If): - result = ConjunctionConstraint(result, cm.make_constraints(cur._node.cond)) + result.extend(cm.make_constraints(cur._node.cond)) cur = cur.parent() - return result + return tuple(result) -def generate_args(args, constraint: Constraint, cm: ConstraintMaker): +def generate_args(args, constraint: tuple[Constraint], cm: ConstraintMaker): arg_values = {} control_values = {} - assignments = cm.solve_constraint( - constraint, bound=CONTROL_VAL_BOUND, search_limit=SEARCH_LIMIT - ) + assignments = cm.solve_constraints(constraint, search_limit=SEARCH_LIMIT) if assignments is None: return None for arg in args: @@ -232,15 +216,14 @@ def generate_args(args, constraint: Constraint, cm: ConstraintMaker): def fuzz_reorder_stmts(s1, s2): - print(s1) proc = s1.get_root() proc_type_visitor = TypeVisitor() 
proc_type_visitor.visit(proc) cm = ConstraintMaker(proc_type_visitor.type_map) - constraint = Constraint(()) + constraints = [] for pred in proc.preds: - constraint = ConjunctionConstraint(constraint, cm.make_constraints(pred)) - constraint = ConjunctionConstraint(constraint, collect_path_constraints(s1, cm)) + constraints.extend(cm.make_constraints(pred)) + constraints.extend(collect_path_constraints(s1, cm)) args = [ LoopIR.fnarg( name=var, @@ -256,7 +239,7 @@ def fuzz_reorder_stmts(s1, s2): arg for arg in args if arg.type.is_numeric() ] for _ in range(TEST_CASE_BOUND): - arg_vals1 = generate_args(args, constraint, cm) + arg_vals1 = generate_args(args, tuple(constraints), cm) if arg_vals1 is None: continue arg_vals2 = { diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 4509a57ca..9c6cf40c3 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -4,12 +4,14 @@ from exo.core.prelude import Sym from ..core.LoopIR import LoopIR, T import numpy as np +from scipy.optimize import linprog +from hsnf import smith_normal_form @dataclass class ConstraintTerm: coefficient: int - syms: tuple[Sym] + syms: tuple[Sym, ...] def negate(self) -> "ConstraintTerm": return ConstraintTerm(-self.coefficient, self.syms) @@ -53,7 +55,7 @@ class LinearConstraint: @dataclass class Constraint: - terms: tuple[ConstraintTerm] + terms: tuple[ConstraintTerm, ...] def apply_assignments( self, assignments: dict[Sym, int] @@ -96,11 +98,11 @@ def pretty_print(self) -> str: class ConstraintMaker: def __init__(self, type_map: dict[Sym, LoopIR.type]): - self.unconstrained_var_subs: dict[Sym, tuple[ConstraintTerm]] = {} + self.unconstrained_var_subs: dict[Sym, tuple[ConstraintTerm, ...]] = {} self.extra_constraints: list[Constraint] = [] self.stride_dummies: dict[tuple[Sym, int], Sym] = {} for sym, sym_type in type_map.items(): - if isinstance(sym_type, T.Size, T.Stride): + if isinstance(sym_type, (T.Size, T.Stride)): # positive constraint self.extra_constraints.append( Constraint( @@ -111,7 +113,7 @@ def __init__(self, type_map: dict[Sym, LoopIR.type]): ) ) ) - elif isinstance(sym_type, T.Int, T.Num): + elif isinstance(sym_type, (T.Int, T.Num)): # unsigned variables are represented as a - b, where a and b are nonnegative a, b = Sym("a"), Sym("b") self.unconstrained_var_subs[sym] = ( @@ -125,14 +127,18 @@ def __init__(self, type_map: dict[Sym, LoopIR.type]): ( ConstraintTerm(1, (sym,)), ConstraintTerm(-1, ()), - ConstraintTerm(1, (Sym("slack"))), + ConstraintTerm(1, (Sym("slack"),)), ) ) ) - def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: + def make_constraint_terms( + self, expr: Union[LoopIR.expr, Sym] + ) -> tuple[ConstraintTerm, ...]: # expect that expr is int type - if isinstance(expr, LoopIR.Read): + if isinstance(expr, Sym): + return (ConstraintTerm(1, (expr,)),) + elif isinstance(expr, LoopIR.Read): assert ( len(expr.idx) == 0 ), "indexing not supported in assertions (yet, todo)" @@ -162,8 +168,8 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: div, rem = Sym("div"), Sym("rem") self.extra_constraints.append( Constraint( - lhs_terms - + (ConstraintTerm(1, (rem,))) + tuple(lhs_term.negate() for lhs_term in lhs_terms) + + (ConstraintTerm(1, (rem,)),) + tuple( rhs_term.multiply(ConstraintTerm(1, (div,))) for rhs_term in rhs_terms @@ -175,6 +181,7 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: ( ConstraintTerm(-1, (rem,)), ConstraintTerm(-1, 
(Sym("slack"),)), + ConstraintTerm(-1, ()), ) + rhs_terms ) @@ -186,7 +193,6 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: if (expr.name, expr.dim) not in self.stride_dummies: new_sym = Sym("stride") self.stride_dummies[(expr.name, expr.dim)] = new_sym - self.nonneg_vars.add(new_sym) dummy = self.stride_dummies[(expr.name, expr.dim)] return (ConstraintTerm(1, (dummy,)),) else: @@ -195,7 +201,7 @@ def make_constraint_terms(self, expr: LoopIR.expr) -> tuple[ConstraintTerm]: def make_constraints( self, expr: LoopIR.expr, - ) -> tuple[Constraint]: + ) -> tuple[Constraint, ...]: # expect that expr is bool type if isinstance(expr, LoopIR.BinOp): if expr.op == "and": @@ -227,14 +233,14 @@ def make_constraints( ) elif isinstance(expr, LoopIR.Const): if expr.val: - return Constraint(()) + return (Constraint(()),) else: - return Constraint((ConstraintTerm(1, ()))) + return (Constraint((ConstraintTerm(1, ()))),) else: assert False, "only boolean expected" def make_constraint_from_inequality( - self, lhs: LoopIR.expr, rhs: LoopIR.expr, op: str + self, lhs: Union[LoopIR.expr, Sym], rhs: Union[LoopIR.expr, Sym], op: str ) -> Constraint: lhs_terms = self.make_constraint_terms(lhs) rhs_terms = self.make_constraint_terms(rhs) @@ -261,7 +267,7 @@ def make_constraint_from_inequality( def solve_constraints( self, - constraints: tuple[Constraint], + constraints: tuple[Constraint, ...], *, search_limit: int, seed: Optional[int] = None, @@ -270,32 +276,140 @@ def solve_constraints( np.random.seed(seed=seed) all_constraints = constraints + tuple(self.extra_constraints) assignments = {} + x_bound = 100 + sym_universe = set() + for constraint in all_constraints: + sym_universe |= constraint.collect_syms() - def solve_recursive(): - linear_constraints: list[LinearConstraint] = [] - linear_constraint_syms: set[Sym] = set() - nonlinear_syms: set[Sym] = set() - for constraint in all_constraints: - assign_result = constraint.apply_assignments(assignments) - if assign_result is not None: - linear_constraints.append(assign_result) - linear_constraint_syms |= { - sym for sym in assign_result.coefficients.keys() - } + def solve_helper(): + while len(assignments) < len(sym_universe): + linear_constraints: list[LinearConstraint] = [] + linear_constraint_syms: set[Sym] = set() + nonlinear_syms: set[Sym] = set() + for constraint in all_constraints: + assign_result = constraint.apply_assignments(assignments) + if assign_result is not None: + linear_constraints.append(assign_result) + linear_constraint_syms |= { + sym for sym in assign_result.coefficients.keys() + } - nonlinear_syms |= constraint.collect_nonlinear_syms() - sym_ordering = {sym: i for i, sym in enumerate(linear_constraint_syms)} - matrix_Ab = np.zeros( - (len(linear_constraints), len(linear_constraint_syms) + 1), - dtype=np.int32, - ) - for row, linear_constraint in enumerate(linear_constraints): - for sym, coefficient in linear_constraint.coefficients: - matrix_Ab[row, sym_ordering[sym]] = coefficient - matrix_Ab[row, len(linear_constraint_syms)] = linear_constraint.offset + nonlinear_syms |= constraint.collect_nonlinear_syms() + nonlinear_syms -= assignments.keys() + priority_syms = nonlinear_syms & linear_constraint_syms + if len(priority_syms) == 0 and len(nonlinear_syms) != 0: + chosen_sym = np.random.choice( + sorted(list(nonlinear_syms), key=lambda sym: sym._id) + ) + assignments[chosen_sym] = np.random.randint(0, x_bound) + continue + sym_ordering = { + sym: i + for i, sym in enumerate( + sorted( + 
list(linear_constraint_syms), + key=lambda sym: sym._id, + ) + ) + } + n = len(linear_constraints) + m = len(linear_constraint_syms) + matrix_A = np.zeros( + (n, m), + dtype=np.int32, + ) + vec_b = np.zeros(n, dtype=np.int32) + for row, linear_constraint in enumerate(linear_constraints): + for sym, coefficient in linear_constraint.coefficients.items(): + matrix_A[row, sym_ordering[sym]] = coefficient + vec_b[row] = -linear_constraint.offset + matrix_B, matrix_U, matrix_V = smith_normal_form(matrix_A) + vec_d = matrix_U @ vec_b + k = min(n, m) + vec_f = np.zeros(m) + for i in range(min(n, m)): + if matrix_B[i, i] == 0: + k = i + break + if vec_d[i] % matrix_B[i, i] != 0: + return False + vec_f += vec_d[i] / matrix_B[i, i] * matrix_V[:, i] + if m == k: + solution = vec_f + if not np.all(vec_f >= 0): + return False + else: + matrix_C = matrix_V[:, k:] + upper_bound_matrix = np.concatenate((matrix_C, -matrix_C), axis=0) + upper_bound_offset = np.concatenate( + (np.ones_like(vec_f) * x_bound - vec_f, vec_f), axis=0 + ) + lp = linprog( + np.zeros(m - k), + A_ub=upper_bound_matrix, + b_ub=upper_bound_offset, + bounds=(None, None), + ) + if not lp.success: + return False + cur_y = lp.x + har_iter = 50 + last_int_y = None + for _ in range(har_iter): + direction = np.random.normal(size=m - k) + direction = direction / np.linalg.norm(direction) + lower_bounds = -matrix_C @ cur_y - vec_f + upper_bounds = lower_bounds + x_bound + coefficients = matrix_C @ direction + lower_bounds = lower_bounds[coefficients != 0] + upper_bounds = upper_bounds[coefficients != 0] + coefficients = coefficients[coefficients != 0] + max_lambda = np.nanmin( + np.where(coefficients < 0, lower_bounds, upper_bounds) + / coefficients + ) + min_lambda = np.nanmax( + np.where(coefficients >= 0, lower_bounds, upper_bounds) + / coefficients + ) + new_y = cur_y + direction * ( + np.random.rand() * (max_lambda - min_lambda) + min_lambda + ) + new_int_y = np.round(new_y) + cur_y = new_y + if np.all(upper_bound_matrix @ new_int_y <= upper_bound_offset): + last_int_y = new_int_y + if last_int_y is not None: + solution = matrix_C @ last_int_y + vec_f + else: + return False + + chosen_sym = None + if len(priority_syms) != 0: + chosen_sym = np.random.choice( + sorted(list(priority_syms), key=lambda sym: sym._id) + ) + elif len(linear_constraint_syms) != 0: + chosen_sym = np.random.choice( + sorted(list(linear_constraint_syms), key=lambda sym: sym._id) + ) + if chosen_sym is None: + free_syms = ( + sym_universe + - linear_constraint_syms + - assignments.keys() + - nonlinear_syms + ) + chosen_sym = np.random.choice( + sorted(list(free_syms), key=lambda sym: sym._id) + ) + assignments[chosen_sym] = np.random.randint(0, x_bound) + else: + assignments[chosen_sym] = int(solution[sym_ordering[chosen_sym]]) + return True for _ in range(search_limit): - if solve_recursive(): + if solve_helper(): return assignments else: assignments = {} diff --git a/tests/golden/test_constraint_solver/test_divmod.txt b/tests/golden/test_constraint_solver/test_divmod.txt index 4887b6f37..4ddb55357 100644 --- a/tests/golden/test_constraint_solver/test_divmod.txt +++ b/tests/golden/test_constraint_solver/test_divmod.txt @@ -1 +1,3 @@ -(((4 * a + 1 * b + -1 * c + -1 >= 0) or (3 + -1 * a >= 0)) and (5 + -1 * b + -1 >= 0)) and ((3 + -1 * rem >= 0) and (1 * rem + -3 >= 0)) \ No newline at end of file +3 * c + -1 * c * a + -1 * c * slack + -12 * a + 4 * a * a + 4 * a * slack + -3 * b + 1 * b * a + 1 * b * slack + 3 * slack + -1 * slack * a + -1 * slack * slack + 3 + -1 * a 
+ -1 * slack == 0 +5 + -1 * b + -1 * slack + -1 == 0 +3 + -1 * rem == 0 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod_solve.txt b/tests/golden/test_constraint_solver/test_divmod_solve.txt index b8d548a89..3eae0fa09 100644 --- a/tests/golden/test_constraint_solver/test_divmod_solve.txt +++ b/tests/golden/test_constraint_solver/test_divmod_solve.txt @@ -1 +1 @@ -rem = 3, b = 4, a = 11, div = 2, c = 11 \ No newline at end of file +a = 31, slack = 1, div = 7, b = 2, slack = 0, c = 48, slack = 47, slack = 2, rem = 3, slack = 30, slack = 84, slack = 77 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_make_constraint.txt b/tests/golden/test_constraint_solver/test_make_constraint.txt index 7291f502e..5da4ee0b6 100644 --- a/tests/golden/test_constraint_solver/test_make_constraint.txt +++ b/tests/golden/test_constraint_solver/test_make_constraint.txt @@ -1 +1,2 @@ -((4 * a + 1 * b + -1 * c + -1 >= 0) or (3 + -1 * a >= 0)) and (5 + -1 * b + -1 >= 0) \ No newline at end of file +3 * c + -1 * c * a + -1 * c * slack + -12 * a + 4 * a * a + 4 * a * slack + -3 * b + 1 * b * a + 1 * b * slack + 3 * slack + -1 * slack * a + -1 * slack * slack + 3 + -1 * a + -1 * slack == 0 +5 + -1 * b + -1 * slack + -1 == 0 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt b/tests/golden/test_constraint_solver/test_solve.txt index 8d0755a0f..ea82ef5bc 100644 --- a/tests/golden/test_constraint_solver/test_solve.txt +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -1 +1 @@ -b = 2, a = 2, c = 12 \ No newline at end of file +a = 70, slack = 96, slack = 0, slack = 69, b = 4, c = 97, slack = 3, slack = 80, slack = 186 \ No newline at end of file diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index 1ac0b61e6..e5d9ffb24 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -15,11 +15,13 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) - assert ( - golden - == ConstraintMaker(foo_type.type_map) - .make_constraints(foo._loopir_proc.preds[0]) - .pretty_print() + assert golden == "\n".join( + [ + constraint.pretty_print() + for constraint in ConstraintMaker(foo_type.type_map).make_constraints( + foo._loopir_proc.preds[0] + ) + ] ) @@ -32,12 +34,12 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) - constraint = cm.make_constraints(foo._loopir_proc.preds[0]) + constraints = cm.make_constraints(foo._loopir_proc.preds[0]) assert golden == ", ".join( [ f"{str(sym)} = {val}" - for sym, val in cm.solve_constraint( - constraint, bound=16, search_limit=10, seed=13 + for sym, val in cm.solve_constraints( + constraints, search_limit=10, seed=13 ).items() ] ) @@ -52,8 +54,10 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) - constraint = cm.make_constraints(foo._loopir_proc.preds[0]) - assert golden == constraint.pretty_print() + constraints = cm.make_constraints(foo._loopir_proc.preds[0]) + assert golden == "\n".join( + [constraint.pretty_print() for constraint in constraints] + ) def test_divmod_solve(golden): @@ -65,12 +69,12 @@ def foo(a: size, b: size, c: size): foo_type = TypeVisitor() foo_type.visit(foo._loopir_proc) cm = ConstraintMaker(foo_type.type_map) - constraint = cm.make_constraints(foo._loopir_proc.preds[0]) + constraints 
= cm.make_constraints(foo._loopir_proc.preds[0]) assert golden == ", ".join( [ f"{str(sym)} = {val}" - for sym, val in cm.solve_constraint( - constraint, bound=16, search_limit=10, seed=13 + for sym, val in cm.solve_constraints( + constraints, search_limit=10, seed=13 ).items() ] ) From b74c3623e06f27e92798e07c89f1c393e2eee288 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 4 Mar 2025 15:32:59 -0500 Subject: [PATCH 06/24] restructure constraint solver for disjunction + write tests --- src/exo/libs/externs.py | 49 +-- src/exo/platforms/gemmini.py | 8 +- src/exo/rewrite/chexo.py | 321 +++++++++++------ src/exo/rewrite/constraint_solver.py | 328 +++++++++++++----- .../test_chexo/test_arg_size_constraints.txt | 5 + .../golden/test_chexo/test_free_variables.txt | 3 + .../test_get_used_config_fields.txt | 1 + .../test_chexo/test_path_constraints.txt | 9 + tests/golden/test_chexo/test_type_visitor.txt | 9 + .../test_disjunction.txt | 10 + .../test_constraint_solver/test_divmod.txt | 15 +- .../test_divmod_solve.txt | 4 +- .../test_constraint_solver/test_inversion.txt | 38 ++ .../test_large_slack.txt | 1 + .../test_make_constraint.txt | 12 +- .../test_constraint_solver/test_solve.txt | 4 +- tests/test_chexo.py | 103 ++++++ tests/test_constraint_solver.py | 100 +++--- 18 files changed, 752 insertions(+), 268 deletions(-) create mode 100644 tests/golden/test_chexo/test_arg_size_constraints.txt create mode 100644 tests/golden/test_chexo/test_free_variables.txt create mode 100644 tests/golden/test_chexo/test_get_used_config_fields.txt create mode 100644 tests/golden/test_chexo/test_path_constraints.txt create mode 100644 tests/golden/test_chexo/test_type_visitor.txt create mode 100644 tests/golden/test_constraint_solver/test_disjunction.txt create mode 100644 tests/golden/test_constraint_solver/test_inversion.txt create mode 100644 tests/golden/test_constraint_solver/test_large_slack.txt create mode 100644 tests/test_chexo.py diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py index 8752ed695..4b125e415 100644 --- a/src/exo/libs/externs.py +++ b/src/exo/libs/externs.py @@ -1,4 +1,5 @@ from exo.core.extern import Extern, _EErr +import numpy as np class _Sin(Extern): @@ -20,8 +21,8 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.sin(args[0]) + def interpret(self, args): + return np.sin(args[0]) def compile(self, args, prim_type): return f"sin(({prim_type}){args[0]})" @@ -55,11 +56,11 @@ def globl(self, prim_type): ) return s - # def interpret(self, args): - # if args[0] > 0: - # return args[0] - # else: - # return 0 + def interpret(self, args): + if args[0] > 0: + return args[0] + else: + return 0 def compile(self, args, prim_type): return f"_relu_{prim_type}(({prim_type}){args[0]})" @@ -95,15 +96,15 @@ def globl(self, prim_type): ) return s - # def interpret(self, args): - # x = args[0] - # v = args[1] - # y = args[2] - # z = args[3] - # if x < v: - # return y - # else: - # return z + def interpret(self, args): + x = args[0] + v = args[1] + y = args[2] + z = args[3] + if x < v: + return y + else: + return z def compile(self, args, prim_type): return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})" @@ -131,8 +132,8 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.expf(args[0]) + def interpret(self, args): + return np.exp(args[0]) def compile(self, args, prim_type): 
return f"expf(({prim_type})({args[0]}))" @@ -161,8 +162,8 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.fmaxf(args[0], args[1]) + def interpret(self, args): + return np.nanmax([args[0], args[1]]) def compile(self, args, prim_type): return f"fmaxf(({prim_type})({args[0]}), ({prim_type})({args[1]}))" @@ -195,8 +196,8 @@ def globl(self, prim_type): }} """ - # def interpret(self, args): - # return math.sigmoid(args[0]) + def interpret(self, args): + return 1 / (1 + np.exp(-args[0])) def compile(self, args, prim_type): return f"sigmoid(({prim_type})({args[0]}))" @@ -224,8 +225,8 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.sqrt(args[0]) + def interpret(self, args): + return np.sqrt(args[0]) def compile(self, args, prim_type): return f"sqrt(({prim_type})({args[0]}))" diff --git a/src/exo/platforms/gemmini.py b/src/exo/platforms/gemmini.py index 5f598817b..500e1a60a 100644 --- a/src/exo/platforms/gemmini.py +++ b/src/exo/platforms/gemmini.py @@ -386,8 +386,8 @@ def ld_i8_block( assert n <= 16 assert m <= 4 assert stride(src, 1) == 1 - assert stride(dst, 0) == 16 - assert stride(dst, 1) == 1 + assert stride(dst, 1) == 16 + assert stride(dst, 2) == 1 for i in seq(0, n): for j in seq(0, m): @@ -481,8 +481,8 @@ def zero_block_id2( ): assert n <= 16 assert m <= 4 - assert stride(dst, 0) == 16 - assert stride(dst, 1) == 1 + assert stride(dst, 1) == 16 + assert stride(dst, 2) == 1 for i in seq(0, n): for j in seq(0, m): diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index 705f8926f..a92b4038d 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -1,15 +1,21 @@ -from typing import Optional +from typing import Optional, Union + +from ..core.configs import Config from ..core.LoopIR import LoopIR, T from dataclasses import dataclass, field from ..core.prelude import Sym, SrcInfo from ..core.memory import DRAM, Memory -from ..backend.LoopIR_interpreter import run_interpreter +from ..backend.LoopIR_interpreter import Interpreter, run_interpreter import numpy as np from .new_eff import SchedulingError from .constraint_solver import ( + TRUE_CONSTRAINT, Constraint, + ConstraintClause, ConstraintMaker, + ConstraintTerm, + DisjointConstraint, ) @@ -60,6 +66,17 @@ def visit(self, node): self.visit_generic(node) +@dataclass +class ConfigVisitor(LoopIRVisitor): + config_reads: dict[tuple[str, str], LoopIR.type] = field(default_factory=lambda: {}) + + def visit(self, node): + if isinstance(node, LoopIR.ReadConfig): + self.config_reads[(node.config.name(), node.field)] = node.type + else: + self.visit_generic(node) + + @dataclass class UsedVariableVisitor(LoopIRVisitor): used_vars: set[Sym] = field(default_factory=lambda: set()) @@ -71,6 +88,13 @@ def visit(self, node): self.visit_generic(node) +def get_used_config_fields(fragment): + config_visitor = ConfigVisitor() + for stmt in fragment: + config_visitor.visit(stmt) + return config_visitor.config_reads + + def get_free_variables(type_map, mem_map, fragment): fragment_type_visitor = TypeVisitor() fragment_var_visitor = UsedVariableVisitor() @@ -86,17 +110,19 @@ def get_free_variables(type_map, mem_map, fragment): } -def eval_tensor_dimension(dim_expr, control_values): +def eval_tensor_dimension( + dim_expr: LoopIR.expr, arg_values: dict[Sym, Union[int, bool, float, np.ndarray]] +) -> int: if isinstance(dim_expr, LoopIR.Read): - return control_values[dim_expr.name] + return 
arg_values[dim_expr.name] elif isinstance(dim_expr, LoopIR.Const): return dim_expr.val elif isinstance(dim_expr, LoopIR.USub): - return -eval_tensor_dimension(dim_expr.arg, control_values) + return -eval_tensor_dimension(dim_expr.arg, arg_values) elif isinstance(dim_expr, LoopIR.BinOp): lhs, rhs = eval_tensor_dimension( - dim_expr.lhs, control_values - ), eval_tensor_dimension(dim_expr.rhs, control_values) + dim_expr.lhs, arg_values + ), eval_tensor_dimension(dim_expr.rhs, arg_values) if dim_expr.op == "+": return lhs + rhs elif dim_expr.op == "-": @@ -113,117 +139,180 @@ def eval_tensor_dimension(dim_expr, control_values): return lhs / rhs elif dim_expr.op == "%": return lhs % rhs - elif dim_expr.op == "==": - return lhs == rhs - elif dim_expr.op == "<": - return lhs < rhs - elif dim_expr.op == ">": - return lhs > rhs - elif dim_expr.op == "<=": - return lhs <= rhs - elif dim_expr.op == ">=": - return lhs >= rhs - elif dim_expr.op == "and": - return lhs and rhs - elif dim_expr.op == "or": - return lhs or rhs + else: + assert False, "unexpected binop in tensor dimension" else: assert False, "unexpected expression type in tensor dimension" CONTROL_VAL_BOUND = 128 +MIN_BUFFER_SIZE_BOUND = 16**1 +MAX_BUFFER_SIZE_BOUND = 16**6 SEARCH_LIMIT = 10 INT_BOUND = 128 FLOAT_BOUND = 32 -def collect_path_constraints(cursor, cm: ConstraintMaker) -> tuple[Constraint]: +def collect_path_constraints(cursor, cm: ConstraintMaker) -> DisjointConstraint: cur = cursor - result = [] + result = TRUE_CONSTRAINT + last_attr = None while cur.depth() != 0: if isinstance(cur._node, LoopIR.For): - result.append( - cm.make_constraint_from_inequality(cur._node.iter, cur._node.lo, ">=") + result = result.intersect( + cm.make_constraint_from_inequality( + cur._node.iter, cur._node.lo, ">=" + ).lift_to_disjoint_constraint() ) - result.append( - cm.make_constraint_from_inequality(cur._node.iter, cur._node.hi, "<") + result = result.intersect( + cm.make_constraint_from_inequality( + cur._node.iter, cur._node.hi, "<" + ).lift_to_disjoint_constraint() ) elif isinstance(cur._node, LoopIR.If): - result.extend(cm.make_constraints(cur._node.cond)) + constraint = cm.make_constraint(cur._node.cond) + if isinstance(last_attr, tuple) and last_attr[0] == "orelse": + result = result.intersect(constraint.invert()) + else: + result = result.intersect(constraint) + last_attr = cur._path[-1] + cur = cur.parent() - return tuple(result) + return result + + +def collect_arg_size_constraints( + args: list[LoopIR.fnarg], cm: ConstraintMaker, buffer_size_bound: int +) -> DisjointConstraint: + constraint = TRUE_CONSTRAINT + for arg in args: + if arg.type.is_tensor_or_window(): + dim_terms: tuple[ConstraintTerm, ...] 
= (ConstraintTerm(1, ()),) + for dim_expr in arg.type.shape(): + dim_terms = tuple( + dim_term.multiply(rhs_term) + for dim_term in dim_terms + for rhs_term in cm.make_constraint_terms(dim_expr) + ) + constraint = constraint.intersect( + Constraint( + tuple(term.negate() for term in dim_terms) + + (ConstraintTerm(buffer_size_bound, ()),), + True, + ).lift_to_disjoint_constraint() + ) + return constraint + + +@dataclass +class TestCase: + arg_values: dict[Sym, Union[int, bool, float, np.ndarray]] + ctxt: dict[tuple[str, str], Union[int, bool, float, np.ndarray]] + + +def generate_control_value(var_type: LoopIR.type): + if isinstance(var_type, T.Bool): + return np.random.rand() < 0.5 + elif isinstance(var_type, (T.Size, T.Stride)): + return np.random.randint(1, CONTROL_VAL_BOUND) + elif isinstance(var_type, (T.Int, T.Index)): + return np.random.randint(-CONTROL_VAL_BOUND, CONTROL_VAL_BOUND) + else: + assert False, "not a control type" -def generate_args(args, constraint: tuple[Constraint], cm: ConstraintMaker): +def generate_numeric_value(var_type: LoopIR.type, shape: Optional[tuple[int]]): + if isinstance(var_type, (T.F32, T.Num)): + dtype = np.float32 + elif isinstance(var_type, T.F16): + dtype = np.float16 + elif isinstance(var_type, T.F64): + dtype = np.float64 + elif isinstance(var_type, T.INT8): + dtype = np.int8 + elif isinstance(var_type, T.INT32): + dtype = np.int32 + elif isinstance(var_type, T.UINT8): + dtype = np.uint8 + elif isinstance(var_type, T.UINT16): + dtype = np.uint16 + else: + assert False, "not a numeric type" + + if dtype in [np.int8, np.int32]: + return np.random.randint(-INT_BOUND, INT_BOUND, shape, dtype=dtype) + elif dtype in [np.uint8, np.uint16]: + return np.random.randint(0, INT_BOUND, shape, dtype=dtype) + elif dtype in [np.float16, np.float32, np.float64]: + if shape is None: + return (np.random.rand() * 2 - 1) * FLOAT_BOUND + else: + return ((np.random.rand(*shape) * 2 - 1) * FLOAT_BOUND).astype(dtype) + else: + assert False, "unreachable" + + +def generate_test_case( + args: list[LoopIR.fnarg], + config_fields: dict[tuple[str, str], LoopIR.type], + constraint: DisjointConstraint, + cm: ConstraintMaker, +) -> Optional[TestCase]: + ctxt = {} arg_values = {} - control_values = {} - assignments = cm.solve_constraints(constraint, search_limit=SEARCH_LIMIT) - if assignments is None: + solution = cm.solve_constraint( + constraint, bound=CONTROL_VAL_BOUND, search_limit=SEARCH_LIMIT + ) + if solution is None: return None + for (config_name, field), field_type in config_fields.items(): + if (config_name, field) in solution.ctxt: + ctxt[(config_name, field)] = solution.ctxt[(config_name, field)] + else: + if field_type.is_numeric(): + val = generate_numeric_value(field_type, (1,)) + else: + val = generate_control_value(field_type) + ctxt[(config_name, field)] = val + for arg in args: if not arg.type.is_numeric(): - if arg.name in assignments: - val = assignments[arg.name] - elif isinstance(arg.type, T.Bool): - val = np.random.randint(0, CONTROL_VAL_BOUND) < CONTROL_VAL_BOUND / 2 + if arg.name in solution.var_assignments: + if isinstance(arg.type, T.Bool): + val = solution.var_assignments[arg.name] != 0 + else: + val = solution.var_assignments[arg.name] else: - val = np.random.randint(0, CONTROL_VAL_BOUND) - control_values[arg.name] = val - arg_values[str(arg.name)] = val + val = generate_control_value(arg.type) + arg_values[arg.name] = val for arg in args: if arg.type.is_numeric(): - basetype = arg.type.basetype() - if isinstance(basetype, (T.F32, T.Num)): - dtype = 
np.float32 - elif isinstance(basetype, T.F16): - dtype = np.float16 - elif isinstance(basetype, T.F64): - dtype = np.float64 - elif isinstance(basetype, T.INT8): - dtype = np.int8 - elif isinstance(basetype, T.INT32): - dtype = np.int32 - elif isinstance(basetype, T.UINT8): - dtype = np.uint8 - elif isinstance(basetype, T.UINT16): - dtype = np.uint16 - if arg.type.is_real_scalar(): shape = (1,) else: shape = tuple( - eval_tensor_dimension(dim_expr, control_values) + eval_tensor_dimension(dim_expr, arg_values) for dim_expr in arg.type.shape() ) - if dtype in [np.int8, np.int32]: - arg_values[str(arg.name)] = np.random.randint( - -INT_BOUND, INT_BOUND, shape, dtype=dtype - ) - elif dtype in [np.uint8, np.uint16]: - arg_values[str(arg.name)] = np.random.randint( - 0, INT_BOUND, shape, dtype=dtype - ) - elif dtype in [np.float16, np.float32, np.float64]: - arg_values[str(arg.name)] = ( - np.random.rand(*shape) * FLOAT_BOUND - ).astype(dtype) + arg_values[arg.name] = generate_numeric_value(arg.type.basetype(), shape) - return arg_values + return TestCase(arg_values, ctxt) -TEST_CASE_BOUND = 10 +TEST_CASE_BOUND = 15 def fuzz_reorder_stmts(s1, s2): proc = s1.get_root() proc_type_visitor = TypeVisitor() proc_type_visitor.visit(proc) + config_fields = get_used_config_fields([s1._node, s2._node]) cm = ConstraintMaker(proc_type_visitor.type_map) - constraints = [] + constraint = TRUE_CONSTRAINT for pred in proc.preds: - constraints.extend(cm.make_constraints(pred)) - constraints.extend(collect_path_constraints(s1, cm)) + constraint = constraint.intersect(cm.make_constraint(pred)) + constraint = constraint.intersect(collect_path_constraints(s1, cm)) args = [ LoopIR.fnarg( name=var, @@ -238,37 +327,73 @@ def fuzz_reorder_stmts(s1, s2): args = [arg for arg in args if not arg.type.is_numeric()] + [ arg for arg in args if arg.type.is_numeric() ] + buffer_size_bound = MIN_BUFFER_SIZE_BOUND + print("start") + print(constraint.pretty_print()) + print("end") for _ in range(TEST_CASE_BOUND): - arg_vals1 = generate_args(args, tuple(constraints), cm) - if arg_vals1 is None: + test_case = generate_test_case( + args, + config_fields, + ( + constraint + if buffer_size_bound is None + else constraint.intersect( + collect_arg_size_constraints(args, cm, buffer_size_bound) + ) + ), + cm, + ) + if test_case is None: + if buffer_size_bound is None or buffer_size_bound >= MAX_BUFFER_SIZE_BOUND: + if buffer_size_bound is None: + print(constraint.pretty_print()) + assert buffer_size_bound is not None + buffer_size_bound = None + else: + buffer_size_bound = min(MAX_BUFFER_SIZE_BOUND, buffer_size_bound * 4) continue + arg_vals1 = test_case.arg_values arg_vals2 = { key: val.copy() if isinstance(val, np.ndarray) else val for key, val in arg_vals1.items() } + ctxt1 = test_case.ctxt + ctxt2 = { + key: val.copy() if isinstance(val, np.ndarray) else val + for key, val in ctxt1.items() + } - run_interpreter( - LoopIR.proc( - name=proc.name, - args=args, - preds=[], - body=[s1._node, s2._node], - instr=None, - srcinfo=proc.srcinfo, - ), - arg_vals1, - ) - run_interpreter( - LoopIR.proc( - name=proc.name, - args=args, - preds=[], - body=[s2._node, s1._node], - instr=None, - srcinfo=proc.srcinfo, - ), - arg_vals2, - ) + try: + interpret1 = Interpreter( + LoopIR.proc( + name=proc.name, + args=args, + preds=[], + body=[s1._node, s2._node], + instr=None, + srcinfo=proc.srcinfo, + ), + arg_vals1, + ctxt1, + ) + interpret2 = Interpreter( + LoopIR.proc( + name=proc.name, + args=args, + preds=[], + body=[s2._node, s1._node], + instr=None, + 
srcinfo=proc.srcinfo, + ), + arg_vals2, + ctxt2, + ) + except Exception as e: + print(e) for x in arg_vals1: if not np.allclose(arg_vals1[x], arg_vals2[x]): raise SchedulingError("mismatch found") + for key, val in interpret1.ctxt.items(): + if key not in interpret2.ctxt or interpret2.ctxt[key] != val: + raise SchedulingError("context mismatch found") diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 9c6cf40c3..38cd3e6b0 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -1,11 +1,12 @@ from dataclasses import dataclass, field -from typing import Union, Optional +from typing import Literal, Union, Optional from exo.core.prelude import Sym from ..core.LoopIR import LoopIR, T import numpy as np from scipy.optimize import linprog from hsnf import smith_normal_form +import textwrap @dataclass @@ -51,11 +52,13 @@ def collect_nonlinear_syms(self) -> frozenset[Sym]: class LinearConstraint: coefficients: dict[Sym, int] offset: int + has_slack: bool @dataclass class Constraint: terms: tuple[ConstraintTerm, ...] + has_slack: bool def apply_assignments( self, assignments: dict[Sym, int] @@ -74,7 +77,7 @@ def apply_assignments( if sym not in coefficients: coefficients[sym] = 0 coefficients[sym] += coefficient - return LinearConstraint(coefficients, offset) + return LinearConstraint(coefficients, offset, self.has_slack) def collect_syms(self) -> frozenset[Sym]: return frozenset(sym for term in self.terms for sym in term.syms) @@ -84,6 +87,33 @@ def collect_nonlinear_syms(self) -> frozenset[Sym]: *[term.collect_nonlinear_syms() for term in self.terms] ) + def lift_to_disjoint_constraint(self) -> "DisjointConstraint": + return DisjointConstraint((ConstraintClause((self,)),)) + + def invert(self) -> "DisjointConstraint": + if self.has_slack: + return Constraint( + tuple(term.negate() for term in self.terms) + (ConstraintTerm(-1, ()),), + True, + ).lift_to_disjoint_constraint() + else: + return DisjointConstraint( + ( + ConstraintClause( + (Constraint(self.terms + (ConstraintTerm(-1, ()),), True),) + ), + ConstraintClause( + ( + Constraint( + tuple(term.negate() for term in self.terms) + + (ConstraintTerm(-1, ()),), + True, + ), + ) + ), + ) + ) + def pretty_print(self) -> str: return ( " + ".join( @@ -92,60 +122,143 @@ def pretty_print(self) -> str: for term in self.terms ] ) - + " == 0" + + f" {'>=' if self.has_slack else '=='} 0" + ) + + +@dataclass +class ConstraintClause: + constraints: tuple[Constraint, ...] + + def invert(self) -> "DisjointConstraint": + acc = FALSE_CONSTRAINT + for constraint in self.constraints: + acc = acc.union(constraint.invert()) + return acc + + def pretty_print(self) -> str: + lines = [ + "intersect(", + *list( + textwrap.indent(constraint.pretty_print(), "\t") + "," + for constraint in self.constraints + ), + ")", + ] + return "\n".join(lines) + + +@dataclass +class DisjointConstraint: + clauses: tuple[ConstraintClause, ...] 
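# DisjointConstraint keeps formulas in disjunctive normal form: a union of
# clauses, each clause an intersection of linear (in)equalities. Inverting a
# slack inequality `t >= 0` gives `-t - 1 >= 0`, while inverting an equality
# `t == 0` gives the union of `t - 1 >= 0` and `-t - 1 >= 0`; the invert()
# methods then push negation through clauses and unions with De Morgan's laws.
#
# solve_constraint (further down in this file) picks one clause and solves its
# integer linear system A @ x == b through the Smith normal form B = U @ A @ V:
# substituting x = V @ y diagonalizes the system, so an integer solution exists
# exactly when each nonzero diagonal entry of B divides the matching entry of
# U @ b. A minimal standalone sketch of that step, assuming only numpy and hsnf
# (both pinned in requirements.txt), not code taken from this patch:
import numpy as np
from hsnf import smith_normal_form

A = np.array([[2, 4], [0, 3]])  # encodes 2*x0 + 4*x1 == 6 and 3*x1 == 3
b = np.array([6, 3])
B, U, V = smith_normal_form(A)  # B == U @ A @ V with B diagonal
d = U @ b
assert all(d[i] % B[i, i] == 0 for i in range(2))  # divisibility <=> solvable over Z
x = V @ (d // np.diag(B))  # particular integer solution
assert np.array_equal(A @ x, b)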
+ + def intersect(self, other: "DisjointConstraint"): + return DisjointConstraint( + tuple( + ConstraintClause(lhs_clause.constraints + rhs_clause.constraints) + for lhs_clause in self.clauses + for rhs_clause in other.clauses + ) ) + def union(self, other: "DisjointConstraint"): + return DisjointConstraint(self.clauses + other.clauses) + + def invert(self) -> "DisjointConstraint": + acc = TRUE_CONSTRAINT + for clause in self.clauses: + acc = acc.intersect(clause.invert()) + return acc + + def pretty_print(self) -> str: + lines = [ + "union(", + *list( + textwrap.indent(clause.pretty_print(), "\t") + "," + for clause in self.clauses + ), + ")", + ] + return "\n".join(lines) + + +TRUE_CONSTRAINT = DisjointConstraint((ConstraintClause(()),)) +FALSE_CONSTRAINT = DisjointConstraint(()) + + +@dataclass +class Expression: + terms: tuple[ConstraintTerm, ...] + + def apply_assignments(self, assignments: dict[Sym, int]) -> Optional[int]: + result = 0 + for term in self.terms: + assign_result = term.apply_assignments(assignments) + if assign_result is None: + return None + else: + coeff, target = assign_result + if target is None: + result += coeff + else: + return None + return result + + +@dataclass +class Solution: + ctxt: dict[tuple[str, str], int] + var_assignments: dict[Sym, int] + class ConstraintMaker: def __init__(self, type_map: dict[Sym, LoopIR.type]): - self.unconstrained_var_subs: dict[Sym, tuple[ConstraintTerm, ...]] = {} + self.var_subs: dict[Sym, Expression] = {} + self.ctxt: dict[tuple[str, str], Expression] = {} self.extra_constraints: list[Constraint] = [] self.stride_dummies: dict[tuple[Sym, int], Sym] = {} for sym, sym_type in type_map.items(): - if isinstance(sym_type, (T.Size, T.Stride)): - # positive constraint - self.extra_constraints.append( - Constraint( - ( - ConstraintTerm(1, (sym,)), - ConstraintTerm(-1, ()), - ConstraintTerm(-1, (Sym("slack"),)), - ) - ) - ) - elif isinstance(sym_type, (T.Int, T.Num)): - # unsigned variables are represented as a - b, where a and b are nonnegative - a, b = Sym("a"), Sym("b") - self.unconstrained_var_subs[sym] = ( - ConstraintTerm(1, (a,)), - ConstraintTerm(-1, (b,)), - ) - elif isinstance(sym_type, T.Bool): - # constrained to [0, 1] - self.extra_constraints.append( - Constraint( - ( - ConstraintTerm(1, (sym,)), - ConstraintTerm(-1, ()), - ConstraintTerm(1, (Sym("slack"),)), - ) - ) + var_sub_result = self.make_var_sub(sym.name(), sym_type) + if var_sub_result is not None: + self.var_subs[sym] = var_sub_result + + def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression]: + if isinstance(var_type, (T.Size, T.Stride)): + # positive variable + return Expression( + (ConstraintTerm(1, (Sym(f"{name}_m1"),)), ConstraintTerm(1, ())) + ) + elif isinstance(var_type, (T.Int, T.Index)): + # unsigned variables are represented as a - b, where a and b are nonnegative + a, b = Sym(f"{name}_a"), Sym(f"{name}_b") + return Expression((ConstraintTerm(1, (a,)), ConstraintTerm(-1, (b,)))) + elif isinstance(var_type, T.Bool): + # constrained to [0, 1] + sym = Sym(name) + self.extra_constraints.append( + Constraint( + ( + ConstraintTerm(-1, (sym,)), + ConstraintTerm(1, ()), + ), + True, ) + ) + return Expression((ConstraintTerm(1, (sym,)),)) + else: + return None def make_constraint_terms( self, expr: Union[LoopIR.expr, Sym] ) -> tuple[ConstraintTerm, ...]: # expect that expr is int type if isinstance(expr, Sym): - return (ConstraintTerm(1, (expr,)),) + return self.var_subs[expr].terms elif isinstance(expr, LoopIR.Read): assert ( 
len(expr.idx) == 0 ), "indexing not supported in assertions (yet, todo)" - if expr.name in self.unconstrained_var_subs: - return self.unconstrained_var_subs[expr.name] - else: - return (ConstraintTerm(1, (expr.name,)),) + return self.var_subs[expr.name].terms elif isinstance(expr, LoopIR.Const): return (ConstraintTerm(expr.val, ()),) elif isinstance(expr, LoopIR.USub): @@ -173,17 +286,18 @@ def make_constraint_terms( + tuple( rhs_term.multiply(ConstraintTerm(1, (div,))) for rhs_term in rhs_terms - ) + ), + False, ) ) self.extra_constraints.append( Constraint( ( ConstraintTerm(-1, (rem,)), - ConstraintTerm(-1, (Sym("slack"),)), ConstraintTerm(-1, ()), ) - + rhs_terms + + rhs_terms, + True, ) ) return (ConstraintTerm(1, (rem if expr.op == "%" else div,)),) @@ -195,47 +309,51 @@ def make_constraint_terms( self.stride_dummies[(expr.name, expr.dim)] = new_sym dummy = self.stride_dummies[(expr.name, expr.dim)] return (ConstraintTerm(1, (dummy,)),) + elif isinstance(expr, LoopIR.ReadConfig): + if (expr.config.name(), expr.field) not in self.ctxt: + field_type = expr.config.lookup_type(expr.field) + var_sub_result = self.make_var_sub( + f"{expr.config.name()}_{expr.field}", field_type + ) + assert ( + var_sub_result is not None + ), "constraints can only occur on control variables" + self.ctxt[(expr.config.name(), expr.field)] = var_sub_result + return self.ctxt[(expr.config.name(), expr.field)].terms else: assert False, f"unsupported expr" - def make_constraints( + def make_constraint( self, expr: LoopIR.expr, - ) -> tuple[Constraint, ...]: + ) -> DisjointConstraint: # expect that expr is bool type if isinstance(expr, LoopIR.BinOp): if expr.op == "and": - return self.make_constraints(expr.lhs) + self.make_constraints(expr.rhs) + lhs_constraints, rhs_constraints = self.make_constraint( + expr.lhs + ), self.make_constraint(expr.rhs) + return lhs_constraints.intersect(rhs_constraints) elif expr.op == "or": - # disjunction multiplies all constraints - lhs_constraints, rhs_constraints = self.make_constraints( + lhs_constraints, rhs_constraints = self.make_constraint( expr.lhs - ), self.make_constraints(expr.rhs) - return tuple( - Constraint( - tuple( - lhs_term.multiply(rhs_term) - for lhs_term in lhs_constraint.terms - for rhs_term in rhs_constraint.terms - ) - ) - for lhs_constraint in lhs_constraints - for rhs_constraint in rhs_constraints - ) + ), self.make_constraint(expr.rhs) + return lhs_constraints.union(rhs_constraints) else: - return ( - self.make_constraint_from_inequality(expr.lhs, expr.rhs, expr.op), - ) + return self.make_constraint_from_inequality( + expr.lhs, expr.rhs, expr.op + ).lift_to_disjoint_constraint() elif isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0, "cannot index into boolean" - return ( - Constraint((ConstraintTerm(1, (expr.name,)), ConstraintTerm(-1, ()))), - ) + return Constraint( + ( + ConstraintTerm(1, (expr.name,)), + ConstraintTerm(-1, ()), + ), + True, + ).lift_to_disjoint_constraint() elif isinstance(expr, LoopIR.Const): - if expr.val: - return (Constraint(()),) - else: - return (Constraint((ConstraintTerm(1, ()))),) + return TRUE_CONSTRAINT if expr.val else FALSE_CONSTRAINT else: assert False, "only boolean expected" @@ -244,39 +362,46 @@ def make_constraint_from_inequality( ) -> Constraint: lhs_terms = self.make_constraint_terms(lhs) rhs_terms = self.make_constraint_terms(rhs) - main_terms = rhs_terms + tuple(term.negate() for term in lhs_terms) + has_slack = True if op == "<": - slack_terms = ( - ConstraintTerm(-1, (Sym("slack"),)), - 
ConstraintTerm(-1, ()), + terms = ( + rhs_terms + + tuple(term.negate() for term in lhs_terms) + + (ConstraintTerm(-1, ()),) ) elif op == ">": - slack_terms = ( - ConstraintTerm(1, (Sym("slack"),)), - ConstraintTerm(1, ()), + terms = ( + lhs_terms + + tuple(term.negate() for term in rhs_terms) + + (ConstraintTerm(-1, ()),) ) elif op == "<=": - slack_terms = (ConstraintTerm(-1, (Sym("slack"),)),) + terms = rhs_terms + tuple(term.negate() for term in lhs_terms) elif op == ">=": - slack_terms = (ConstraintTerm(1, (Sym("slack"),)),) + terms = lhs_terms + tuple(term.negate() for term in rhs_terms) elif op == "==": - slack_terms = () + has_slack = False + terms = rhs_terms + tuple(term.negate() for term in lhs_terms) else: assert False, "boolean ops expected" - return Constraint(main_terms + slack_terms) + return Constraint(terms, has_slack) - def solve_constraints( + def solve_constraint( self, - constraints: tuple[Constraint, ...], + disjoint_constraint: DisjointConstraint, *, + bound: int, search_limit: int, seed: Optional[int] = None, - ): + ) -> Optional[Solution]: if seed is not None: np.random.seed(seed=seed) - all_constraints = constraints + tuple(self.extra_constraints) + if len(disjoint_constraint.clauses) == 0: + return None + chosen_clause = np.random.choice(list(disjoint_constraint.clauses)) + assert isinstance(chosen_clause, ConstraintClause) + all_constraints = chosen_clause.constraints + tuple(self.extra_constraints) assignments = {} - x_bound = 100 sym_universe = set() for constraint in all_constraints: sym_universe |= constraint.collect_syms() @@ -293,7 +418,6 @@ def solve_helper(): linear_constraint_syms |= { sym for sym in assign_result.coefficients.keys() } - nonlinear_syms |= constraint.collect_nonlinear_syms() nonlinear_syms -= assignments.keys() priority_syms = nonlinear_syms & linear_constraint_syms @@ -301,7 +425,7 @@ def solve_helper(): chosen_sym = np.random.choice( sorted(list(nonlinear_syms), key=lambda sym: sym._id) ) - assignments[chosen_sym] = np.random.randint(0, x_bound) + assignments[chosen_sym] = np.random.randint(0, bound) continue sym_ordering = { sym: i @@ -313,15 +437,21 @@ def solve_helper(): ) } n = len(linear_constraints) - m = len(linear_constraint_syms) + m_nonslack = len(linear_constraint_syms) matrix_A = np.zeros( - (n, m), + (n, m_nonslack), dtype=np.int32, ) + m = m_nonslack vec_b = np.zeros(n, dtype=np.int32) for row, linear_constraint in enumerate(linear_constraints): for sym, coefficient in linear_constraint.coefficients.items(): matrix_A[row, sym_ordering[sym]] = coefficient + if linear_constraint.has_slack: + slack_col = np.zeros((n, 1), dtype=np.int32) + slack_col[row, 0] = -1 + matrix_A = np.hstack((matrix_A, slack_col)) + m += 1 vec_b[row] = -linear_constraint.offset matrix_B, matrix_U, matrix_V = smith_normal_form(matrix_A) vec_d = matrix_U @ vec_b @@ -340,9 +470,12 @@ def solve_helper(): return False else: matrix_C = matrix_V[:, k:] - upper_bound_matrix = np.concatenate((matrix_C, -matrix_C), axis=0) + upper_bound_matrix = np.concatenate( + (matrix_C[:m_nonslack, :], -matrix_C), axis=0 + ) upper_bound_offset = np.concatenate( - (np.ones_like(vec_f) * x_bound - vec_f, vec_f), axis=0 + (np.ones(m_nonslack) * bound - vec_f[:m_nonslack], vec_f), + axis=0, ) lp = linprog( np.zeros(m - k), @@ -359,7 +492,8 @@ def solve_helper(): direction = np.random.normal(size=m - k) direction = direction / np.linalg.norm(direction) lower_bounds = -matrix_C @ cur_y - vec_f - upper_bounds = lower_bounds + x_bound + upper_bounds = lower_bounds + bound + 
upper_bounds[m_nonslack:] = -np.nan coefficients = matrix_C @ direction lower_bounds = lower_bounds[coefficients != 0] upper_bounds = upper_bounds[coefficients != 0] @@ -403,14 +537,24 @@ def solve_helper(): chosen_sym = np.random.choice( sorted(list(free_syms), key=lambda sym: sym._id) ) - assignments[chosen_sym] = np.random.randint(0, x_bound) + assignments[chosen_sym] = np.random.randint(0, bound) else: assignments[chosen_sym] = int(solution[sym_ordering[chosen_sym]]) return True for _ in range(search_limit): if solve_helper(): - return assignments + var_assignments = {} + for sym, sub in self.var_subs.items(): + result = sub.apply_assignments(assignments) + if result is not None: + var_assignments[sym] = result + ctxt = {} + for (config_name, field), sub in self.ctxt.items(): + result = sub.apply_assignments(assignments) + if result is not None: + ctxt[(config_name, field)] = result + return Solution(ctxt, var_assignments) else: assignments = {} diff --git a/tests/golden/test_chexo/test_arg_size_constraints.txt b/tests/golden/test_chexo/test_arg_size_constraints.txt new file mode 100644 index 000000000..84ae7cad0 --- /dev/null +++ b/tests/golden/test_chexo/test_arg_size_constraints.txt @@ -0,0 +1,5 @@ +union( + intersect( + -2 * a_m1 * b_m1 + -2 * a_m1 + -2 * a_m1 + -2 * b_m1 + -2 + -2 + 256 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_chexo/test_free_variables.txt b/tests/golden/test_chexo/test_free_variables.txt new file mode 100644 index 000000000..eac397ee3 --- /dev/null +++ b/tests/golden/test_chexo/test_free_variables.txt @@ -0,0 +1,3 @@ +a: (size, ) +b: (f32[a], ) +i: (index, None) \ No newline at end of file diff --git a/tests/golden/test_chexo/test_get_used_config_fields.txt b/tests/golden/test_chexo/test_get_used_config_fields.txt new file mode 100644 index 000000000..b50f0ce3c --- /dev/null +++ b/tests/golden/test_chexo/test_get_used_config_fields.txt @@ -0,0 +1 @@ +(TestConfig, a): f32 \ No newline at end of file diff --git a/tests/golden/test_chexo/test_path_constraints.txt b/tests/golden/test_chexo/test_path_constraints.txt new file mode 100644 index 000000000..5d6eee9b3 --- /dev/null +++ b/tests/golden/test_chexo/test_path_constraints.txt @@ -0,0 +1,9 @@ +union( + intersect( + 1 * j_a + -1 * j_b + 0 >= 0, + 1 * a_m1 + 1 + -1 * j_a + 1 * j_b + -1 >= 0, + -1 * a_m1 + -1 + 2 * i_a + -2 * i_b + 1 + -1 >= 0, + 1 * i_a + -1 * i_b + 0 >= 0, + 1 * a_m1 + 1 + -1 * i_a + 1 * i_b + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_chexo/test_type_visitor.txt b/tests/golden/test_chexo/test_type_visitor.txt new file mode 100644 index 000000000..bfa6588c7 --- /dev/null +++ b/tests/golden/test_chexo/test_type_visitor.txt @@ -0,0 +1,9 @@ +Types: +a: size +b: f32[a] +c: f32 +i: index +Mems: +a: +b: +c: \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_disjunction.txt b/tests/golden/test_constraint_solver/test_disjunction.txt new file mode 100644 index 000000000..4d14fa6b7 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_disjunction.txt @@ -0,0 +1,10 @@ +union( + intersect( + 3 + -1 * a_m1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 4 + -1 * b_m1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod.txt b/tests/golden/test_constraint_solver/test_divmod.txt index 4ddb55357..336c90602 100644 --- a/tests/golden/test_constraint_solver/test_divmod.txt +++ 
b/tests/golden/test_constraint_solver/test_divmod.txt @@ -1,3 +1,12 @@ -3 * c + -1 * c * a + -1 * c * slack + -12 * a + 4 * a * a + 4 * a * slack + -3 * b + 1 * b * a + 1 * b * slack + 3 * slack + -1 * slack * a + -1 * slack * slack + 3 + -1 * a + -1 * slack == 0 -5 + -1 * b + -1 * slack + -1 == 0 -3 + -1 * rem == 0 \ No newline at end of file +union( + intersect( + 4 * a_m1 + 4 + 1 * b_m1 + 1 + -1 * c_m1 + -1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + 3 + -1 * rem == 0, + ), + intersect( + 3 + -1 * a_m1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + 3 + -1 * rem == 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod_solve.txt b/tests/golden/test_constraint_solver/test_divmod_solve.txt index 3eae0fa09..f94895ebf 100644 --- a/tests/golden/test_constraint_solver/test_divmod_solve.txt +++ b/tests/golden/test_constraint_solver/test_divmod_solve.txt @@ -1 +1,3 @@ -a = 31, slack = 1, div = 7, b = 2, slack = 0, c = 48, slack = 47, slack = 2, rem = 3, slack = 30, slack = 84, slack = 77 \ No newline at end of file +a = 15 +b = 4 +c = 15 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_inversion.txt b/tests/golden/test_constraint_solver/test_inversion.txt new file mode 100644 index 000000000..2874affa4 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_inversion.txt @@ -0,0 +1,38 @@ +union( + intersect( + -3 + 1 * a_m1 + 1 + -1 >= 0, + -4 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + -3 + 1 * a_m1 + 1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + -3 + 1 * a_m1 + 1 + -1 >= 0, + -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + -4 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + -4 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_large_slack.txt b/tests/golden/test_constraint_solver/test_large_slack.txt new file mode 100644 index 000000000..7aa6ef31c --- /dev/null +++ b/tests/golden/test_constraint_solver/test_large_slack.txt @@ -0,0 +1 @@ +a = 79 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_make_constraint.txt b/tests/golden/test_constraint_solver/test_make_constraint.txt index 5da4ee0b6..8e1b28b4e 100644 --- a/tests/golden/test_constraint_solver/test_make_constraint.txt +++ b/tests/golden/test_constraint_solver/test_make_constraint.txt @@ -1,2 +1,10 @@ -3 * c + -1 * c * a + -1 * c * slack + -12 * a + 4 * a * a + 4 * a * slack + -3 * b + 1 * b * a + 1 * b * slack + 3 * slack + -1 * slack * a + -1 * slack * slack + 3 + -1 * a + -1 * slack == 0 -5 + -1 * b + -1 * slack + -1 == 0 \ No newline at end of file +union( + intersect( + 4 * a_m1 + 4 + 1 * b_m1 + 1 + -1 * c_m1 + -1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 3 + -1 * a_m1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt 
b/tests/golden/test_constraint_solver/test_solve.txt index ea82ef5bc..fabab028a 100644 --- a/tests/golden/test_constraint_solver/test_solve.txt +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -1 +1,3 @@ -a = 70, slack = 96, slack = 0, slack = 69, b = 4, c = 97, slack = 3, slack = 80, slack = 186 \ No newline at end of file +a = 22 +b = 2 +c = 16 \ No newline at end of file diff --git a/tests/test_chexo.py b/tests/test_chexo.py new file mode 100644 index 000000000..6221fd160 --- /dev/null +++ b/tests/test_chexo.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from exo.core.prelude import Sym + +from exo.rewrite.chexo import ( + TypeVisitor, + get_used_config_fields, + get_free_variables, + collect_path_constraints, + collect_arg_size_constraints, +) +from exo.rewrite.constraint_solver import ConstraintMaker +from exo import proc, config +from exo.core.memory import StaticMemory + + +def stringify_dict(d): + def check_tuple(x): + if isinstance(x, tuple): + return f"({', '.join(str(x_item) for x_item in x)})" + else: + return str(x) + + return "\n".join( + sorted(f"{check_tuple(k)}: {check_tuple(v)}" for k, v in d.items()) + ) + + +def test_type_visitor(golden): + @proc + def foo(a: size, b: f32[a]): + for i in seq(0, a): + c: f32 @ StaticMemory + c = b[i] * 2 + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + types = stringify_dict(type_visitor.type_map) + mems = stringify_dict(type_visitor.mem_map) + assert golden == f"Types:\n{types}\nMems:\n{mems}" + + +def test_get_used_config_fields(golden): + @config + class TestConfig: + a: f32 + b: size + c: f32 + + @proc + def foo(a: f32): + TestConfig.c = a + a = TestConfig.a + + used_configs = get_used_config_fields(foo._loopir_proc.body) + assert golden == stringify_dict(used_configs) + + +def test_free_variables(golden): + @proc + def foo(a: size, b: f32[a]): + for i in seq(0, a): + c: f32 @ StaticMemory + c = b[i] * 2 + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + free_vars = get_free_variables( + type_visitor.type_map, + type_visitor.mem_map, + [cursor._impl._node for cursor in foo.find("c: _").as_block().expand()], + ) + assert golden == stringify_dict(free_vars) + + +def test_path_constraints(golden): + @proc + def foo(a: size, b: f32[a]): + for i in seq(0, a): + if 2 * i < a: + b[i] = 0 + else: + for j in seq(0, a): + b[j] = b[i] + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + cm = ConstraintMaker(type_visitor.type_map) + assert ( + golden + == collect_path_constraints(foo.find("b[j] = b[i]")._impl, cm).pretty_print() + ) + + +def test_arg_size_constraints(golden): + @proc + def foo(a: size, b: size, c: f32[a * 2, b + 1]): + pass + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + cm = ConstraintMaker(type_visitor.type_map) + constraints = collect_arg_size_constraints(foo._loopir_proc.args, cm) + assert golden == constraints.pretty_print() diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index e5d9ffb24..fd0ae7906 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -1,28 +1,43 @@ from __future__ import annotations from exo.core.prelude import Sym -from exo.rewrite.constraint_solver import ConstraintMaker -from exo.core.LoopIR import LoopIR +from exo.rewrite.constraint_solver import ConstraintMaker, DisjointConstraint +from exo.core.LoopIR import T from exo import proc from exo.rewrite.chexo import TypeVisitor +def stringify_proc_constraint(p, invert=False): + 
p_type = TypeVisitor() + p_type.visit(p._loopir_proc) + constraint = ConstraintMaker(p_type.type_map).make_constraint( + p._loopir_proc.preds[0] + ) + return (constraint.invert() if invert else constraint).pretty_print() + + +def solve_proc_assertion(p): + p_type = TypeVisitor() + p_type.visit(p._loopir_proc) + cm = ConstraintMaker(p_type.type_map) + constraint = cm.make_constraint(p._loopir_proc.preds[0]) + return "\n".join( + [ + f"{str(sym)} = {val}" + for sym, val in cm.solve_constraint( + constraint, bound=100, search_limit=10, seed=13 + ).var_assignments.items() + ] + ) + + def test_make_constraint(golden): @proc def foo(a: size, b: size, c: size): assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) pass - foo_type = TypeVisitor() - foo_type.visit(foo._loopir_proc) - assert golden == "\n".join( - [ - constraint.pretty_print() - for constraint in ConstraintMaker(foo_type.type_map).make_constraints( - foo._loopir_proc.preds[0] - ) - ] - ) + assert golden == stringify_proc_constraint(foo) def test_solve(golden): @@ -31,18 +46,7 @@ def foo(a: size, b: size, c: size): assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) pass - foo_type = TypeVisitor() - foo_type.visit(foo._loopir_proc) - cm = ConstraintMaker(foo_type.type_map) - constraints = cm.make_constraints(foo._loopir_proc.preds[0]) - assert golden == ", ".join( - [ - f"{str(sym)} = {val}" - for sym, val in cm.solve_constraints( - constraints, search_limit=10, seed=13 - ).items() - ] - ) + assert golden == solve_proc_assertion(foo) def test_divmod(golden): @@ -51,13 +55,7 @@ def foo(a: size, b: size, c: size): assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) and (a % 4 == 3) pass - foo_type = TypeVisitor() - foo_type.visit(foo._loopir_proc) - cm = ConstraintMaker(foo_type.type_map) - constraints = cm.make_constraints(foo._loopir_proc.preds[0]) - assert golden == "\n".join( - [constraint.pretty_print() for constraint in constraints] - ) + assert golden == stringify_proc_constraint(foo) def test_divmod_solve(golden): @@ -66,15 +64,31 @@ def foo(a: size, b: size, c: size): assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) and (a % 4 == 3) pass - foo_type = TypeVisitor() - foo_type.visit(foo._loopir_proc) - cm = ConstraintMaker(foo_type.type_map) - constraints = cm.make_constraints(foo._loopir_proc.preds[0]) - assert golden == ", ".join( - [ - f"{str(sym)} = {val}" - for sym, val in cm.solve_constraints( - constraints, search_limit=10, seed=13 - ).items() - ] - ) + assert golden == solve_proc_assertion(foo) + + +def test_large_slack(golden): + @proc + def foo(a: size): + assert a <= 1000000 + pass + + assert golden == solve_proc_assertion(foo) + + +def test_disjunction(golden): + @proc + def foo(a: size, b: size): + assert (a <= 3 or b <= 4) and (a + b < 4) + pass + + assert golden == stringify_proc_constraint(foo) + + +def test_inversion(golden): + @proc + def foo(a: size, b: size): + assert (a <= 3 or b <= 4) and (a + b == 4) + pass + + assert golden == stringify_proc_constraint(foo, True) From d0a1cdf85e062c565428e52a93bc0385c818c084 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Fri, 14 Mar 2025 02:46:13 -0400 Subject: [PATCH 07/24] finish transpiler --- requirements.txt | 3 +- src/exo/backend/LoopIR_transpiler.py | 355 +++++++++++++++++++++++++++ src/exo/core/extern.py | 3 + src/exo/libs/externs.py | 21 ++ 4 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 src/exo/backend/LoopIR_transpiler.py diff --git a/requirements.txt b/requirements.txt index dbdf8f7a8..d27a4d023 100644 --- a/requirements.txt +++ b/requirements.txt 
@@ -5,4 +5,5 @@ build==1.2.2.post1 z3-solver==4.14.0.0 yapf==0.43.0 scipy==1.6.2 -hsnf==0.3.16 \ No newline at end of file +hsnf==0.3.16 +pythonmonkey==1.1.0 \ No newline at end of file diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py new file mode 100644 index 000000000..cf4fee1b0 --- /dev/null +++ b/src/exo/backend/LoopIR_transpiler.py @@ -0,0 +1,355 @@ +from functools import reduce +from string import Template +from typing import Any, Iterable, Union + +from .. import Config + +from ..core.prelude import Sym +from ..core.LoopIR import LoopIR, T +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class BaseType: + loopir_type: Any + dtype: type + javascript_array_type: str + + +base_types = ( + BaseType(T.F16, np.float16, "Float16Array"), + BaseType(T.F32, np.float32, "Float32Array"), + BaseType(T.F64, np.float64, "Float64Array"), + BaseType(T.INT8, np.int8, "Int8Array"), + BaseType(T.UINT8, np.uint8, "Uint8Array"), + BaseType(T.UINT16, np.uint16, "Uint16Array"), + BaseType(T.INT32, np.int32, "Int32Array"), + BaseType(T.Num, np.float64, "Float64Array"), +) + + +def lookup_loopir_type(loopir_type: Any) -> BaseType: + return next( + base_type + for base_type in base_types + if isinstance(loopir_type, base_type.loopir_type) + ) + + +def lookup_dtype(dtype: type) -> BaseType: + return next(base_type for base_type in base_types if base_type.dtype == dtype) + + +@dataclass +class Constant: + name: str + + +@dataclass +class Reference: + name: str + + +@dataclass +class Dimension: + size: str + stride: str + + +@dataclass +class Tensor: + name: str + offset: str + dims: tuple[Dimension, ...] + + +ExoValue = Union[Constant, Reference, Tensor] + + +CONTEXT_OBJECT_NAME = "ctxt" + + +class Transpiler: + def __init__(self, proc: LoopIR.proc): + self.name_lookup: dict[Sym, ExoValue] = {} + self.js_lines: list[str] = [] + self.configs: set[tuple[Config, str]] = set() + self.buffer_args: list[Sym] = [] + self.transpile_proc(proc) + + def get_javascript_template(self) -> Template: + return Template("\n".join(self.js_lines)) + + def get_configs(self) -> tuple[tuple[Config, str], ...]: + return tuple(self.configs) + + def get_buffer_arg_order(self) -> tuple[Sym, ...]: + return tuple(self.buffer_args) + + def get_config_param_name(self, config: Config, field: str) -> str: + return f"config_{config.name()}_{field}" + + def get_stride_param_name(self, tensor_name: Sym, dim_idx: int): + return f"stride_{repr(tensor_name)}_{dim_idx}" + + def get_size_param_name(self, tensor_name: Sym, dim_idx: int): + return f"size_{repr(tensor_name)}_{dim_idx}" + + def assert_at_runtime(self, expr: str): + self.js_lines.append(f"if(!{expr})return 1;") + + def transpile_proc(self, proc: LoopIR.proc): + arg_values = [] + for arg in proc.args: + if arg.type.is_numeric(): + self.buffer_args.append(arg.name) + if arg.type.is_tensor_or_window(): + value = Tensor( + repr(arg.name), + "0", + tuple( + Dimension( + f"${self.get_size_param_name(arg.name, dim_idx)}", + f"${self.get_stride_param_name(arg.name, dim_idx)}", + ) + for dim_idx in range(len(arg.type.shape())) + ), + ) + else: + value = Reference(repr(arg.name)) + else: + value = Constant(f"${repr(arg.name)}") + arg_values.append(value) + self.js_lines.append( + f'(({",".join(repr(arg) for arg in self.buffer_args)})=>{{' + ) + ctxt_placeholder = len(self.js_lines) + self.js_lines.append(f"__placeholder__") + self.call_proc(proc, tuple(arg_values)) + self.js_lines.append("return 0;})") + configs = ",".join( + 
f"{self.get_config_param_name(config, field)}:${self.get_config_param_name(config, field)}" + for config, field in self.configs + ) + self.js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}" + + def call_proc(self, proc: LoopIR.proc, arg_values: tuple[ExoValue, ...]): + for arg, arg_value in zip(proc.args, arg_values): + self.name_lookup[arg.name] = arg_value + if arg.type.is_tensor_or_window(): + assert isinstance(arg_value, Tensor) + for arg_dim, arg_dim_expr in zip(arg_value.dims, arg.type.shape()): + self.assert_at_runtime( + f"({arg_dim.size}=={self.transpile_expr(arg_dim_expr)})", + ) + + for pred in proc.preds: + self.assert_at_runtime(self.transpile_expr(pred)) + + for stmt in proc.body: + self.transpile_stmt(stmt) + + def get_index_expr(self, buf: Tensor, idxs: Iterable[LoopIR.expr]): + idx_exprs = tuple(self.transpile_expr(idx) for idx in idxs) + for idx_expr, dim in zip(idx_exprs, buf.dims): + self.assert_at_runtime(f"({idx_expr}<{dim.size}&&{idx_expr}>=0)") + relative_idx = reduce( + lambda dim1, dim2: f"{dim1}+{dim2}", + ( + f"Math.imul({idx_expr},{dim.stride})" + for idx_expr, dim in zip(idx_exprs, buf.dims) + ), + ) + return f"{relative_idx}+{buf.offset}" + + def transpile_stmt(self, stmt: LoopIR.stmt): + if isinstance(stmt, (LoopIR.Assign, LoopIR.Reduce)): + lhs_buffer = self.name_lookup[stmt.name] + rhs = self.transpile_expr(stmt.rhs) + if isinstance(lhs_buffer, Reference): + lhs = f"{lhs_buffer.name}[0]" + elif isinstance(lhs_buffer, Tensor): + lhs = f"{lhs_buffer.name}[{self.get_index_expr(lhs_buffer, stmt.idx)}]" + else: + assert False + if isinstance(stmt, LoopIR.Assign): + self.js_lines.append(f"{lhs}={rhs};") + else: + self.js_lines.append(f"{lhs}+={rhs};") + elif isinstance(stmt, LoopIR.WriteConfig): + config_name = self.get_config_param_name(stmt.config, stmt.field) + rhs = self.transpile_expr(stmt.rhs) + self.js_lines.append(f"{CONTEXT_OBJECT_NAME}[{config_name}]={rhs};") + self.configs.add((stmt.config, stmt.field)) + elif isinstance(stmt, LoopIR.Pass): + pass + elif isinstance(stmt, LoopIR.If): + cond = self.transpile_expr( + stmt.cond, + ) + self.js_lines.append(f"if({cond}){{") + for body_stmt in stmt.body: + self.transpile_stmt(body_stmt) + self.js_lines.append("}else{") + for else_stmt in stmt.orelse: + self.transpile_stmt(else_stmt) + self.js_lines.append("}") + elif isinstance(stmt, LoopIR.For): + iter_name = repr(stmt.iter) + iter_lo = self.transpile_expr(stmt.lo) + iter_hi = self.transpile_expr(stmt.hi) + self.name_lookup[stmt.iter] = Constant(iter_name) + self.js_lines.append( + f"for(let {iter_name}={iter_lo};{iter_name}<{iter_hi};{iter_name}++){{" + ) + for body_stmt in stmt.body: + self.transpile_stmt(body_stmt) + self.js_lines.append("}") + elif isinstance(stmt, LoopIR.Alloc): + assert stmt.type.is_numeric() + if stmt.type.is_tensor_or_window(): + tensor_name = repr(stmt.name) + buffer_type = lookup_loopir_type( + stmt.type.basetype() + ).javascript_array_type + dim_exprs = tuple(self.transpile_expr(dim) for dim in stmt.type.shape()) + for dim_expr in dim_exprs: + self.assert_at_runtime(f"({dim_expr}>=0)") + buffer_size = reduce( + lambda dim1, dim2: f"Math.imul({dim1},{dim2})", dim_exprs + ) + self.js_lines.append( + f"let {tensor_name}=new {buffer_type}({buffer_size});" + ) + dimensions: list[Dimension] = [] + for dim_idx, dim_expr in enumerate(dim_exprs): + self.assert_at_runtime(f"({dim_expr}>=0)") + stride_expr = reduce( + lambda dim1, dim2: f"Math.imul({dim1},{dim2})", + dim_exprs[dim_idx + 1 :], + "1", + ) + 
dimensions.append(Dimension(dim_expr, stride_expr)) + self.name_lookup[stmt.name] = Tensor( + tensor_name, "0", tuple(dimensions) + ) + else: + ref_name = repr(stmt.name) + buffer_type = lookup_loopir_type(stmt.type).javascript_array_type + self.js_lines.append(f"let {ref_name}=new {buffer_type}(1);") + self.name_lookup[stmt.name] = Reference(ref_name) + elif isinstance(stmt, LoopIR.Free): + pass + elif isinstance(stmt, LoopIR.Call): + self.call_proc( + stmt.f, + tuple( + ( + self.transpile_buffer_arg(arg_expr) + if arg_expr.type.is_numeric() + else Constant(self.transpile_expr(arg_expr)) + ) + for arg_expr in stmt.args + ), + ) + elif isinstance(stmt, LoopIR.WindowStmt): + self.name_lookup[stmt.name] = self.transpile_buffer_arg(stmt.rhs) + else: + assert False, "unsupported stmt" + + def transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: + if isinstance(expr, LoopIR.Read): + assert len(expr.idx) == 0 + buf = self.name_lookup[expr.name] + assert isinstance(buf, (Tensor, Reference)) + return buf + elif isinstance(expr, LoopIR.WindowExpr): + base = self.name_lookup[expr.name] + assert isinstance(base, Tensor) + offset_expr = base.offset + window_dims = [] + for idx, dim in zip(expr.idx, base.dims): + if isinstance(idx, LoopIR.Interval): + lo_expr = self.transpile_expr(idx.lo) + hi_expr = self.transpile_expr(idx.hi) + self.assert_at_runtime( + f"(0<={lo_expr}&&{lo_expr}<={hi_expr}&&{hi_expr}<={dim.size})" + ) + offset_expr = f"({offset_expr}+Math.imul({lo_expr},{dim.stride}))" + size_expr = f"({hi_expr}-{lo_expr})" + window_dims.append(Dimension(size_expr, dim.stride)) + elif isinstance(idx, LoopIR.Point): + pt_expr = self.transpile_expr(idx.pt) + self.assert_at_runtime(f"(0<={pt_expr}&&{pt_expr}<{dim.size})") + offset_expr = f"({offset_expr}+Math.imul({pt_expr},{dim.stride}))" + else: + assert False, "not a window index" + return Tensor(base.name, offset_expr, tuple(window_dims)) + else: + assert False, "unsupported buffer expression" + + def transpile_expr( + self, + expr: LoopIR.expr, + ) -> str: + if isinstance(expr, LoopIR.Read): + buf = self.name_lookup[expr.name] + if isinstance(buf, Tensor): + return f"{buf.name}[{self.get_index_expr(buf, expr.idx)}]" + elif isinstance(buf, Reference): + return f"{buf.name}[0]" + else: + return buf.name + elif isinstance(expr, LoopIR.Const): + if isinstance(expr.val, (int, float)): + return f"{expr.val}" + elif isinstance(expr.val, bool): + return "true" if expr.val else "false" + else: + assert False, "unexpected const type" + elif isinstance(expr, LoopIR.USub): + return f"(-{self.transpile_expr(expr.arg)})" + elif isinstance(expr, LoopIR.BinOp): + lhs = self.transpile_expr(expr.lhs) + rhs = self.transpile_expr(expr.rhs) + is_int = ( + isinstance(expr.type, (T.INT8, T.UINT8, T.UINT16, T.INT32)) + or not expr.type.is_numeric() + ) + if expr.op in ["+", "-", "%", "<", ">", "<=", ">=", "=="]: + val = f"({lhs}{expr.op}{rhs})" + elif expr.op == "*": + val = f"Math.imul({lhs},{rhs})" if is_int else f"({lhs}*{rhs})" + elif expr.op == "/": + val = f"(({lhs}/{rhs})|0)" if is_int else f"({lhs}/{rhs})" + elif expr.op == "and": + val = f"({lhs}&&{rhs})" + elif expr.op == "or": + val = f"({lhs}||{rhs})" + else: + assert False, "invalid op" + if isinstance(expr.type, T.INT8): + return f"(({val}<<24)>>24)" + elif isinstance(expr.type, T.UINT8): + return f"({val}&0xFF)" + elif isinstance(expr.type, T.UINT16): + return f"({val}&0xFFFF)" + else: + return val + elif isinstance(expr, LoopIR.Extern): + return expr.f.transpile( + 
tuple(self.transpile_expr(arg) for arg in expr.args) + ) + elif isinstance(expr, LoopIR.WindowExpr): + assert False, "unexpected window expr" + elif isinstance(expr, LoopIR.StrideExpr): + buf = self.name_lookup[expr.name] + assert isinstance(buf, Tensor) + return buf.dims[expr.dim].stride + elif isinstance(expr, LoopIR.ReadConfig): + self.configs.add((expr.config, expr.field)) + return f"{CONTEXT_OBJECT_NAME}[{self.get_config_param_name(expr.config, expr.field)}]" + else: + assert False, "unexpected expr" diff --git a/src/exo/core/extern.py b/src/exo/core/extern.py index b1ae39d6d..979689f6d 100644 --- a/src/exo/core/extern.py +++ b/src/exo/core/extern.py @@ -32,5 +32,8 @@ def typecheck(self, args): def interpret(self, args): raise NotImplementedError() + def transpile(self, args): + raise NotImplementedError() + def compile(self, args, prim_type): raise NotImplementedError() diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py index 4b125e415..155b14fde 100644 --- a/src/exo/libs/externs.py +++ b/src/exo/libs/externs.py @@ -24,6 +24,9 @@ def globl(self, prim_type): def interpret(self, args): return np.sin(args[0]) + def transpile(self, args): + return f"Math.sin({args[0]})" + def compile(self, args, prim_type): return f"sin(({prim_type}){args[0]})" @@ -62,6 +65,9 @@ def interpret(self, args): else: return 0 + def transpile(self, args): + return f"(({args[0]}>0)?{args[0]}:0)" + def compile(self, args, prim_type): return f"_relu_{prim_type}(({prim_type}){args[0]})" @@ -106,6 +112,9 @@ def interpret(self, args): else: return z + def transpile(self, args): + return f"(({args[0]}<{args[1]})?{args[2]}:{args[3]})" + def compile(self, args, prim_type): return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})" @@ -135,6 +144,9 @@ def globl(self, prim_type): def interpret(self, args): return np.exp(args[0]) + def transpile(self, args): + return f"Math.exp({args[0]})" + def compile(self, args, prim_type): return f"expf(({prim_type})({args[0]}))" @@ -165,6 +177,9 @@ def globl(self, prim_type): def interpret(self, args): return np.nanmax([args[0], args[1]]) + def transpile(self, args): + return f"(({args[0]}<{args[1]}&&{args[1]}=={args[1]})?{args[1]}:{args[0]})" + def compile(self, args, prim_type): return f"fmaxf(({prim_type})({args[0]}), ({prim_type})({args[1]}))" @@ -199,6 +214,9 @@ def globl(self, prim_type): def interpret(self, args): return 1 / (1 + np.exp(-args[0])) + def transpile(self, args): + return f"1/(1+Math.exp(-{args[0]}))" + def compile(self, args, prim_type): return f"sigmoid(({prim_type})({args[0]}))" @@ -228,6 +246,9 @@ def globl(self, prim_type): def interpret(self, args): return np.sqrt(args[0]) + def transpile(self, args): + return f"Math.sqrt({args[0]})" + def compile(self, args, prim_type): return f"sqrt(({prim_type})({args[0]}))" From f084b27619c6dcee25077eec4177a282bc50f465 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 18 Mar 2025 00:08:47 -0400 Subject: [PATCH 08/24] integrate transpiler into chexo --- src/exo/backend/LoopIR_transpiler.py | 185 ++++++++++++----------- src/exo/platforms/gemmini.py | 16 +- src/exo/rewrite/LoopIR_scheduling.py | 34 ++++- src/exo/rewrite/chexo.py | 213 +++++++++++++++++---------- src/exo/rewrite/constraint_solver.py | 17 ++- tests/test_chexo.py | 17 --- 6 files changed, 279 insertions(+), 203 deletions(-) diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index cf4fee1b0..0c422c040 100644 --- 
a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -2,7 +2,7 @@ from string import Template from typing import Any, Iterable, Union -from .. import Config +from ..core.configs import Config from ..core.prelude import Sym from ..core.LoopIR import LoopIR, T @@ -50,6 +50,7 @@ class Constant: @dataclass class Reference: name: str + is_config: bool @dataclass @@ -73,20 +74,20 @@ class Tensor: class Transpiler: def __init__(self, proc: LoopIR.proc): - self.name_lookup: dict[Sym, ExoValue] = {} - self.js_lines: list[str] = [] - self.configs: set[tuple[Config, str]] = set() - self.buffer_args: list[Sym] = [] - self.transpile_proc(proc) + self._name_lookup: dict[Sym, ExoValue] = {} + self._js_lines: list[str] = [] + self._configs: set[tuple[Config, str]] = set() + self._buffer_args: list[Sym] = [] + self._transpile_proc(proc) def get_javascript_template(self) -> Template: - return Template("\n".join(self.js_lines)) + return Template("\n".join(self._js_lines)) - def get_configs(self) -> tuple[tuple[Config, str], ...]: - return tuple(self.configs) + def get_configs(self) -> frozenset[tuple[Config, str]]: + return frozenset(self._configs) def get_buffer_arg_order(self) -> tuple[Sym, ...]: - return tuple(self.buffer_args) + return tuple(self._buffer_args) def get_config_param_name(self, config: Config, field: str) -> str: return f"config_{config.name()}_{field}" @@ -97,14 +98,14 @@ def get_stride_param_name(self, tensor_name: Sym, dim_idx: int): def get_size_param_name(self, tensor_name: Sym, dim_idx: int): return f"size_{repr(tensor_name)}_{dim_idx}" - def assert_at_runtime(self, expr: str): - self.js_lines.append(f"if(!{expr})return 1;") + def _assert_at_runtime(self, expr: str): + self._js_lines.append(f"if(!{expr})return [1,{CONTEXT_OBJECT_NAME}];") - def transpile_proc(self, proc: LoopIR.proc): + def _transpile_proc(self, proc: LoopIR.proc): arg_values = [] for arg in proc.args: if arg.type.is_numeric(): - self.buffer_args.append(arg.name) + self._buffer_args.append(arg.name) if arg.type.is_tensor_or_window(): value = Tensor( repr(arg.name), @@ -118,43 +119,43 @@ def transpile_proc(self, proc: LoopIR.proc): ), ) else: - value = Reference(repr(arg.name)) + value = Reference(repr(arg.name), False) else: value = Constant(f"${repr(arg.name)}") arg_values.append(value) - self.js_lines.append( - f'(({",".join(repr(arg) for arg in self.buffer_args)})=>{{' + self._js_lines.append( + f'(({",".join(repr(arg) for arg in self._buffer_args)})=>{{' ) - ctxt_placeholder = len(self.js_lines) - self.js_lines.append(f"__placeholder__") - self.call_proc(proc, tuple(arg_values)) - self.js_lines.append("return 0;})") + ctxt_placeholder = len(self._js_lines) + self._js_lines.append(f"__placeholder__") + self._call_proc(proc, tuple(arg_values)) + self._js_lines.append(f"return [0,{CONTEXT_OBJECT_NAME}];}})") configs = ",".join( f"{self.get_config_param_name(config, field)}:${self.get_config_param_name(config, field)}" - for config, field in self.configs + for config, field in self._configs ) - self.js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}" + self._js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}" - def call_proc(self, proc: LoopIR.proc, arg_values: tuple[ExoValue, ...]): + def _call_proc(self, proc: LoopIR.proc, arg_values: tuple[ExoValue, ...]): for arg, arg_value in zip(proc.args, arg_values): - self.name_lookup[arg.name] = arg_value + self._name_lookup[arg.name] = arg_value if arg.type.is_tensor_or_window(): assert isinstance(arg_value, 
Tensor) for arg_dim, arg_dim_expr in zip(arg_value.dims, arg.type.shape()): - self.assert_at_runtime( - f"({arg_dim.size}=={self.transpile_expr(arg_dim_expr)})", + self._assert_at_runtime( + f"({arg_dim.size}=={self._transpile_expr(arg_dim_expr)})", ) for pred in proc.preds: - self.assert_at_runtime(self.transpile_expr(pred)) + self._assert_at_runtime(self._transpile_expr(pred)) for stmt in proc.body: - self.transpile_stmt(stmt) + self._transpile_stmt(stmt) - def get_index_expr(self, buf: Tensor, idxs: Iterable[LoopIR.expr]): - idx_exprs = tuple(self.transpile_expr(idx) for idx in idxs) + def _get_index_expr(self, buf: Tensor, idxs: Iterable[LoopIR.expr]): + idx_exprs = tuple(self._transpile_expr(idx) for idx in idxs) for idx_expr, dim in zip(idx_exprs, buf.dims): - self.assert_at_runtime(f"({idx_expr}<{dim.size}&&{idx_expr}>=0)") + self._assert_at_runtime(f"({idx_expr}<{dim.size}&&{idx_expr}>=0)") relative_idx = reduce( lambda dim1, dim2: f"{dim1}+{dim2}", ( @@ -164,49 +165,51 @@ def get_index_expr(self, buf: Tensor, idxs: Iterable[LoopIR.expr]): ) return f"{relative_idx}+{buf.offset}" - def transpile_stmt(self, stmt: LoopIR.stmt): + def _transpile_stmt(self, stmt: LoopIR.stmt): if isinstance(stmt, (LoopIR.Assign, LoopIR.Reduce)): - lhs_buffer = self.name_lookup[stmt.name] - rhs = self.transpile_expr(stmt.rhs) + lhs_buffer = self._name_lookup[stmt.name] + rhs = self._transpile_expr(stmt.rhs) if isinstance(lhs_buffer, Reference): - lhs = f"{lhs_buffer.name}[0]" + lhs = ( + lhs_buffer.name if lhs_buffer.is_config else f"{lhs_buffer.name}[0]" + ) elif isinstance(lhs_buffer, Tensor): - lhs = f"{lhs_buffer.name}[{self.get_index_expr(lhs_buffer, stmt.idx)}]" + lhs = f"{lhs_buffer.name}[{self._get_index_expr(lhs_buffer, stmt.idx)}]" else: assert False if isinstance(stmt, LoopIR.Assign): - self.js_lines.append(f"{lhs}={rhs};") + self._js_lines.append(f"{lhs}={rhs};") else: - self.js_lines.append(f"{lhs}+={rhs};") + self._js_lines.append(f"{lhs}+={rhs};") elif isinstance(stmt, LoopIR.WriteConfig): config_name = self.get_config_param_name(stmt.config, stmt.field) - rhs = self.transpile_expr(stmt.rhs) - self.js_lines.append(f"{CONTEXT_OBJECT_NAME}[{config_name}]={rhs};") - self.configs.add((stmt.config, stmt.field)) + rhs = self._transpile_expr(stmt.rhs) + self._js_lines.append(f'{CONTEXT_OBJECT_NAME}["{config_name}"]={rhs};') + self._configs.add((stmt.config, stmt.field)) elif isinstance(stmt, LoopIR.Pass): pass elif isinstance(stmt, LoopIR.If): - cond = self.transpile_expr( + cond = self._transpile_expr( stmt.cond, ) - self.js_lines.append(f"if({cond}){{") + self._js_lines.append(f"if({cond}){{") for body_stmt in stmt.body: - self.transpile_stmt(body_stmt) - self.js_lines.append("}else{") + self._transpile_stmt(body_stmt) + self._js_lines.append("}else{") for else_stmt in stmt.orelse: - self.transpile_stmt(else_stmt) - self.js_lines.append("}") + self._transpile_stmt(else_stmt) + self._js_lines.append("}") elif isinstance(stmt, LoopIR.For): iter_name = repr(stmt.iter) - iter_lo = self.transpile_expr(stmt.lo) - iter_hi = self.transpile_expr(stmt.hi) - self.name_lookup[stmt.iter] = Constant(iter_name) - self.js_lines.append( + iter_lo = self._transpile_expr(stmt.lo) + iter_hi = self._transpile_expr(stmt.hi) + self._name_lookup[stmt.iter] = Constant(iter_name) + self._js_lines.append( f"for(let {iter_name}={iter_lo};{iter_name}<{iter_hi};{iter_name}++){{" ) for body_stmt in stmt.body: - self.transpile_stmt(body_stmt) - self.js_lines.append("}") + self._transpile_stmt(body_stmt) + 
self._js_lines.append("}") elif isinstance(stmt, LoopIR.Alloc): assert stmt.type.is_numeric() if stmt.type.is_tensor_or_window(): @@ -214,106 +217,114 @@ def transpile_stmt(self, stmt: LoopIR.stmt): buffer_type = lookup_loopir_type( stmt.type.basetype() ).javascript_array_type - dim_exprs = tuple(self.transpile_expr(dim) for dim in stmt.type.shape()) + dim_exprs = tuple( + self._transpile_expr(dim) for dim in stmt.type.shape() + ) for dim_expr in dim_exprs: - self.assert_at_runtime(f"({dim_expr}>=0)") + self._assert_at_runtime(f"({dim_expr}>=0)") buffer_size = reduce( lambda dim1, dim2: f"Math.imul({dim1},{dim2})", dim_exprs ) - self.js_lines.append( + self._js_lines.append( f"let {tensor_name}=new {buffer_type}({buffer_size});" ) dimensions: list[Dimension] = [] for dim_idx, dim_expr in enumerate(dim_exprs): - self.assert_at_runtime(f"({dim_expr}>=0)") + self._assert_at_runtime(f"({dim_expr}>=0)") stride_expr = reduce( lambda dim1, dim2: f"Math.imul({dim1},{dim2})", dim_exprs[dim_idx + 1 :], "1", ) dimensions.append(Dimension(dim_expr, stride_expr)) - self.name_lookup[stmt.name] = Tensor( + self._name_lookup[stmt.name] = Tensor( tensor_name, "0", tuple(dimensions) ) else: ref_name = repr(stmt.name) buffer_type = lookup_loopir_type(stmt.type).javascript_array_type - self.js_lines.append(f"let {ref_name}=new {buffer_type}(1);") - self.name_lookup[stmt.name] = Reference(ref_name) + self._js_lines.append(f"let {ref_name}=new {buffer_type}(1);") + self._name_lookup[stmt.name] = Reference(ref_name, False) elif isinstance(stmt, LoopIR.Free): pass elif isinstance(stmt, LoopIR.Call): - self.call_proc( + self._call_proc( stmt.f, tuple( ( - self.transpile_buffer_arg(arg_expr) + self._transpile_buffer_arg(arg_expr) if arg_expr.type.is_numeric() - else Constant(self.transpile_expr(arg_expr)) + else Constant(self._transpile_expr(arg_expr)) ) for arg_expr in stmt.args ), ) elif isinstance(stmt, LoopIR.WindowStmt): - self.name_lookup[stmt.name] = self.transpile_buffer_arg(stmt.rhs) + self._name_lookup[stmt.name] = self._transpile_buffer_arg(stmt.rhs) else: assert False, "unsupported stmt" - def transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: + def _transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: if isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0 - buf = self.name_lookup[expr.name] + buf = self._name_lookup[expr.name] assert isinstance(buf, (Tensor, Reference)) return buf elif isinstance(expr, LoopIR.WindowExpr): - base = self.name_lookup[expr.name] + base = self._name_lookup[expr.name] assert isinstance(base, Tensor) offset_expr = base.offset window_dims = [] for idx, dim in zip(expr.idx, base.dims): if isinstance(idx, LoopIR.Interval): - lo_expr = self.transpile_expr(idx.lo) - hi_expr = self.transpile_expr(idx.hi) - self.assert_at_runtime( + lo_expr = self._transpile_expr(idx.lo) + hi_expr = self._transpile_expr(idx.hi) + self._assert_at_runtime( f"(0<={lo_expr}&&{lo_expr}<={hi_expr}&&{hi_expr}<={dim.size})" ) offset_expr = f"({offset_expr}+Math.imul({lo_expr},{dim.stride}))" size_expr = f"({hi_expr}-{lo_expr})" window_dims.append(Dimension(size_expr, dim.stride)) elif isinstance(idx, LoopIR.Point): - pt_expr = self.transpile_expr(idx.pt) - self.assert_at_runtime(f"(0<={pt_expr}&&{pt_expr}<{dim.size})") + pt_expr = self._transpile_expr(idx.pt) + self._assert_at_runtime(f"(0<={pt_expr}&&{pt_expr}<{dim.size})") offset_expr = f"({offset_expr}+Math.imul({pt_expr},{dim.stride}))" else: assert False, "not a window index" return Tensor(base.name, 
offset_expr, tuple(window_dims)) + elif isinstance(expr, LoopIR.ReadConfig): + self._configs.add((expr.config, expr.field)) + return Reference( + f'{CONTEXT_OBJECT_NAME}["{self.get_config_param_name(expr.config, expr.field)}"]', + True, + ) else: assert False, "unsupported buffer expression" - def transpile_expr( + def _transpile_expr( self, expr: LoopIR.expr, ) -> str: if isinstance(expr, LoopIR.Read): - buf = self.name_lookup[expr.name] + buf = self._name_lookup[expr.name] if isinstance(buf, Tensor): - return f"{buf.name}[{self.get_index_expr(buf, expr.idx)}]" + return f"{buf.name}[{self._get_index_expr(buf, expr.idx)}]" elif isinstance(buf, Reference): - return f"{buf.name}[0]" + return buf.name if buf.is_config else f"{buf.name}[0]" else: return buf.name elif isinstance(expr, LoopIR.Const): - if isinstance(expr.val, (int, float)): - return f"{expr.val}" - elif isinstance(expr.val, bool): + if isinstance(expr.val, bool): return "true" if expr.val else "false" + elif isinstance(expr.val, (int, float)): + return f"{expr.val}" else: assert False, "unexpected const type" elif isinstance(expr, LoopIR.USub): - return f"(-{self.transpile_expr(expr.arg)})" + return f"(-{self._transpile_expr(expr.arg)})" elif isinstance(expr, LoopIR.BinOp): - lhs = self.transpile_expr(expr.lhs) - rhs = self.transpile_expr(expr.rhs) + lhs = self._transpile_expr(expr.lhs) + rhs = self._transpile_expr(expr.rhs) is_int = ( isinstance(expr.type, (T.INT8, T.UINT8, T.UINT16, T.INT32)) or not expr.type.is_numeric() @@ -340,16 +351,16 @@ def transpile_expr( return val elif isinstance(expr, LoopIR.Extern): return expr.f.transpile( - tuple(self.transpile_expr(arg) for arg in expr.args) + tuple(self._transpile_expr(arg) for arg in expr.args) ) elif isinstance(expr, LoopIR.WindowExpr): assert False, "unexpected window expr" elif isinstance(expr, LoopIR.StrideExpr): - buf = self.name_lookup[expr.name] + buf = self._name_lookup[expr.name] assert isinstance(buf, Tensor) return buf.dims[expr.dim].stride elif isinstance(expr, LoopIR.ReadConfig): - self.configs.add((expr.config, expr.field)) - return f"{CONTEXT_OBJECT_NAME}[{self.get_config_param_name(expr.config, expr.field)}]" + self._configs.add((expr.config, expr.field)) + return f'{CONTEXT_OBJECT_NAME}["{self.get_config_param_name(expr.config, expr.field)}"]' else: assert False, "unexpected expr" diff --git a/src/exo/platforms/gemmini.py b/src/exo/platforms/gemmini.py index 500e1a60a..c52a83efa 100644 --- a/src/exo/platforms/gemmini.py +++ b/src/exo/platforms/gemmini.py @@ -620,7 +620,7 @@ def ld_i8_vector( src: [i8][16] @ DRAM, dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 for i in seq(0, 16): dst[i] = src[i] @@ -636,7 +636,7 @@ def do_ld_i8_vector( src: [i8][16] @ DRAM, dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 for i in seq(0, 16): dst[i] = src[i] @@ -736,8 +736,8 @@ def ld_acc_i32_vector( dst: [i32][n, 16] @ GEMM_ACCUM, ): assert n <= 16 - assert stride(dst, 0) == 1 - assert stride(src, 0) == 1 + assert stride(dst, 1) == 1 + assert stride(src, 1) == 1 for i in seq(0, n): for j in seq(0, 16): @@ -754,8 +754,8 @@ def do_ld_acc_i32_vector( dst: [i32][n, 16] @ GEMM_ACCUM, ): assert n <= 16 - assert stride(dst, 0) == 1 - assert stride(src, 0) == 1 + assert stride(dst, 1) == 1 + assert stride(src, 1) == 1 for i in seq(0, n): for j in seq(0, 16): @@ -1088,7 +1088,7 @@ def del_and_zero(p): def zero_i8_vector( dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert 
stride(dst, 0) == 1 pass for i in seq(0, 16): @@ -1102,7 +1102,7 @@ def zero_i8_vector( def do_zero_i8_vector( dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 for i in seq(0, 16): dst[i] = 0.0 diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index ec3361e18..616b6bb63 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -1,6 +1,6 @@ import re from collections import ChainMap -from typing import List, Tuple, Optional +from typing import Callable, List, Literal, Tuple, Optional from ..core.LoopIR import ( LoopIR, @@ -368,6 +368,31 @@ def divide_expr(e, quot): # Scheduling directives +def do_check( + static_check: Callable[[], None], + dynamic_check: Callable[[], None], + mode: Literal["static", "dynamic", "both"], +): + if mode == "both": + e_static, e_dynamic = None, None + try: + static_check() + except Exception as e: + e_static = e + try: + dynamic_check() + except Exception as e: + e_dynamic = e + if (e_static is None) != (e_dynamic is None): + assert False, "fuzzer should match static analysis" + elif e_static is not None: + raise e_static + elif mode == "static": + static_check() + elif mode == "dynamic": + dynamic_check() + + # Take a conservative approach and allow stmt reordering only when they are # writing to different buffers # TODO: Do effectcheck's check_commutes-ish thing using SMT here @@ -376,8 +401,11 @@ def DoReorderStmt(f_cursor, s_cursor): raise SchedulingError( "expected the second statement to be directly after the first" ) - # Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node) - fuzz_reorder_stmts(f_cursor, s_cursor) + do_check( + lambda: Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node), + lambda: fuzz_reorder_stmts(f_cursor, s_cursor), + "both", + ) ir, fwd = s_cursor._move(f_cursor.before()) return ir, fwd diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index a92b4038d..a28876abb 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -1,5 +1,7 @@ from typing import Optional, Union +from ..backend.LoopIR_transpiler import Transpiler + from ..core.configs import Config from ..core.LoopIR import LoopIR, T @@ -18,6 +20,8 @@ DisjointConstraint, ) +from pythonmonkey import eval as js_eval + class LoopIRVisitor: def visit(self, node): @@ -66,17 +70,6 @@ def visit(self, node): self.visit_generic(node) -@dataclass -class ConfigVisitor(LoopIRVisitor): - config_reads: dict[tuple[str, str], LoopIR.type] = field(default_factory=lambda: {}) - - def visit(self, node): - if isinstance(node, LoopIR.ReadConfig): - self.config_reads[(node.config.name(), node.field)] = node.type - else: - self.visit_generic(node) - - @dataclass class UsedVariableVisitor(LoopIRVisitor): used_vars: set[Sym] = field(default_factory=lambda: set()) @@ -88,11 +81,16 @@ def visit(self, node): self.visit_generic(node) -def get_used_config_fields(fragment): - config_visitor = ConfigVisitor() - for stmt in fragment: - config_visitor.visit(stmt) - return config_visitor.config_reads +@dataclass +class Dimension: + size: int + stride: int + + +@dataclass +class Tensor: + data: np.ndarray + dims: tuple[Dimension, ...] 
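# Aside (rough sketch, not from the patch): this Tensor mirrors a NumPy array
# as flat data plus per-dimension (size, stride), with strides counted in
# elements rather than bytes — roughly what generate_numeric_value below does
# when it divides numpy's byte strides by the dtype's itemsize:
import numpy as np

def to_tensor(arr: np.ndarray) -> "Tensor":
    # hypothetical converter; reuses the Tensor/Dimension dataclasses above
    return Tensor(
        arr.flatten(),
        tuple(
            Dimension(size, stride // arr.dtype.itemsize)  # bytes -> elements
            for size, stride in zip(arr.shape, arr.strides)
        ),
    )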
def get_free_variables(type_map, mem_map, fragment): @@ -111,7 +109,7 @@ def get_free_variables(type_map, mem_map, fragment): def eval_tensor_dimension( - dim_expr: LoopIR.expr, arg_values: dict[Sym, Union[int, bool, float, np.ndarray]] + dim_expr: LoopIR.expr, arg_values: dict[Sym, Union[int, bool, float, Tensor]] ) -> int: if isinstance(dim_expr, LoopIR.Read): return arg_values[dim_expr.name] @@ -206,11 +204,11 @@ def collect_arg_size_constraints( @dataclass class TestCase: - arg_values: dict[Sym, Union[int, bool, float, np.ndarray]] - ctxt: dict[tuple[str, str], Union[int, bool, float, np.ndarray]] + arg_values: dict[Sym, Union[int, bool, float, Tensor]] + ctxt: dict[tuple[Config, str], Union[int, bool, float, Tensor]] -def generate_control_value(var_type: LoopIR.type): +def generate_control_value(var_type: LoopIR.type) -> Union[int, bool, float]: if isinstance(var_type, T.Bool): return np.random.rand() < 0.5 elif isinstance(var_type, (T.Size, T.Stride)): @@ -221,7 +219,7 @@ def generate_control_value(var_type: LoopIR.type): assert False, "not a control type" -def generate_numeric_value(var_type: LoopIR.type, shape: Optional[tuple[int]]): +def generate_numeric_value(var_type: LoopIR.type, shape: tuple[int, ...]) -> Tensor: if isinstance(var_type, (T.F32, T.Num)): dtype = np.float32 elif isinstance(var_type, T.F16): @@ -240,21 +238,26 @@ def generate_numeric_value(var_type: LoopIR.type, shape: Optional[tuple[int]]): assert False, "not a numeric type" if dtype in [np.int8, np.int32]: - return np.random.randint(-INT_BOUND, INT_BOUND, shape, dtype=dtype) + data = np.random.randint(-INT_BOUND, INT_BOUND, shape, dtype=dtype) elif dtype in [np.uint8, np.uint16]: - return np.random.randint(0, INT_BOUND, shape, dtype=dtype) + data = np.random.randint(0, INT_BOUND, shape, dtype=dtype) elif dtype in [np.float16, np.float32, np.float64]: - if shape is None: - return (np.random.rand() * 2 - 1) * FLOAT_BOUND - else: - return ((np.random.rand(*shape) * 2 - 1) * FLOAT_BOUND).astype(dtype) + data = ((np.random.rand(*shape) * 2 - 1) * FLOAT_BOUND).astype(dtype) else: assert False, "unreachable" + return Tensor( + data.flatten(), + tuple( + Dimension(dim_size, dim_stride / data.dtype.itemsize) + for dim_size, dim_stride in zip(data.shape, data.strides) + ), + ) + def generate_test_case( args: list[LoopIR.fnarg], - config_fields: dict[tuple[str, str], LoopIR.type], + config_fields: frozenset[tuple[Config, str]], constraint: DisjointConstraint, cm: ConstraintMaker, ) -> Optional[TestCase]: @@ -265,15 +268,16 @@ def generate_test_case( ) if solution is None: return None - for (config_name, field), field_type in config_fields.items(): - if (config_name, field) in solution.ctxt: - ctxt[(config_name, field)] = solution.ctxt[(config_name, field)] + for config, field in config_fields: + if (config, field) in solution.ctxt: + ctxt[(config, field)] = solution.ctxt[(config, field)] else: + field_type = config.lookup_type(field) if field_type.is_numeric(): val = generate_numeric_value(field_type, (1,)) else: val = generate_control_value(field_type) - ctxt[(config_name, field)] = val + ctxt[(config, field)] = val for arg in args: if not arg.type.is_numeric(): @@ -300,6 +304,68 @@ def generate_test_case( return TestCase(arg_values, ctxt) +@dataclass +class TestResult: + buffer_values: dict[Sym, np.ndarray] + ctxt_object: dict[str, Union[int, float]] + + +def run_test_case(test_case: TestCase, transpiled_proc: Transpiler) -> TestResult: + subs = {} + for arg_name, arg_value in test_case.arg_values.items(): + if 
isinstance(arg_value, Tensor): + for dim_idx, dim in enumerate(arg_value.dims): + subs[transpiled_proc.get_size_param_name(arg_name, dim_idx)] = str( + dim.size + ) + subs[transpiled_proc.get_stride_param_name(arg_name, dim_idx)] = str( + dim.stride + ) + elif isinstance(arg_value, bool): + subs[repr(arg_name)] = "true" if arg_value else "false" + elif isinstance(arg_value, (int, float)): + subs[repr(arg_name)] = str(arg_value) + else: + assert False + for (config, field), config_value in test_case.ctxt.items(): + if isinstance(config_value, Tensor): + assert config_value.data.shape == (1,) + subs[transpiled_proc.get_config_param_name(config, field)] = str( + config_value.data[0] + ) + elif isinstance(config_value, bool): + subs[transpiled_proc.get_config_param_name(config, field)] = ( + "true" if config_value else "false" + ) + elif isinstance(config_value, (int, float)): + subs[transpiled_proc.get_config_param_name(config, field)] = str( + config_value + ) + else: + assert False + + buffer_args = tuple( + test_case.arg_values[buffer_name].data.copy() + for buffer_name in transpiled_proc.get_buffer_arg_order() + ) + javascript = transpiled_proc.get_javascript_template().substitute(subs) + try: + [result, ctxt_object] = js_eval(javascript)(*buffer_args) + except Exception as e: + print(e) + + assert result == 0 + return TestResult( + { + buffer_name: buffer_value + for buffer_name, buffer_value in zip( + transpiled_proc.get_buffer_arg_order(), buffer_args + ) + }, + ctxt_object, + ) + + TEST_CASE_BOUND = 15 @@ -307,7 +373,7 @@ def fuzz_reorder_stmts(s1, s2): proc = s1.get_root() proc_type_visitor = TypeVisitor() proc_type_visitor.visit(proc) - config_fields = get_used_config_fields([s1._node, s2._node]) + cm = ConstraintMaker(proc_type_visitor.type_map) constraint = TRUE_CONSTRAINT for pred in proc.preds: @@ -327,10 +393,30 @@ def fuzz_reorder_stmts(s1, s2): args = [arg for arg in args if not arg.type.is_numeric()] + [ arg for arg in args if arg.type.is_numeric() ] + + transpiled_test1 = Transpiler( + LoopIR.proc( + name=proc.name, + args=args, + preds=[], + body=[s1._node, s2._node], + instr=None, + srcinfo=proc.srcinfo, + ) + ) + transpiled_test2 = Transpiler( + LoopIR.proc( + name=proc.name, + args=args, + preds=[], + body=[s2._node, s1._node], + instr=None, + srcinfo=proc.srcinfo, + ) + ) + + config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() buffer_size_bound = MIN_BUFFER_SIZE_BOUND - print("start") - print(constraint.pretty_print()) - print("end") for _ in range(TEST_CASE_BOUND): test_case = generate_test_case( args, @@ -346,54 +432,21 @@ def fuzz_reorder_stmts(s1, s2): ) if test_case is None: if buffer_size_bound is None or buffer_size_bound >= MAX_BUFFER_SIZE_BOUND: - if buffer_size_bound is None: - print(constraint.pretty_print()) assert buffer_size_bound is not None buffer_size_bound = None else: buffer_size_bound = min(MAX_BUFFER_SIZE_BOUND, buffer_size_bound * 4) continue - arg_vals1 = test_case.arg_values - arg_vals2 = { - key: val.copy() if isinstance(val, np.ndarray) else val - for key, val in arg_vals1.items() - } - ctxt1 = test_case.ctxt - ctxt2 = { - key: val.copy() if isinstance(val, np.ndarray) else val - for key, val in ctxt1.items() - } - - try: - interpret1 = Interpreter( - LoopIR.proc( - name=proc.name, - args=args, - preds=[], - body=[s1._node, s2._node], - instr=None, - srcinfo=proc.srcinfo, - ), - arg_vals1, - ctxt1, - ) - interpret2 = Interpreter( - LoopIR.proc( - name=proc.name, - args=args, - preds=[], - body=[s2._node, s1._node], 
- instr=None, - srcinfo=proc.srcinfo, - ), - arg_vals2, - ctxt2, - ) - except Exception as e: - print(e) - for x in arg_vals1: - if not np.allclose(arg_vals1[x], arg_vals2[x]): + + out1 = run_test_case(test_case, transpiled_test1) + out2 = run_test_case(test_case, transpiled_test2) + for buffer_name in out1.buffer_values.keys() & out2.buffer_values.keys(): + if not np.allclose( + out1.buffer_values[buffer_name], out2.buffer_values[buffer_name] + ): raise SchedulingError("mismatch found") - for key, val in interpret1.ctxt.items(): - if key not in interpret2.ctxt or interpret2.ctxt[key] != val: + for ctxt_name in out1.ctxt_object & out2.ctxt_object.keys(): + if not np.allclose( + out1.ctxt_object[ctxt_name], out2.ctxt_object[ctxt_name] + ): raise SchedulingError("context mismatch found") diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 38cd3e6b0..9f4f23326 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, field from typing import Literal, Union, Optional -from exo.core.prelude import Sym +from ..core.configs import Config +from ..core.prelude import Sym from ..core.LoopIR import LoopIR, T import numpy as np from scipy.optimize import linprog @@ -207,14 +208,14 @@ def apply_assignments(self, assignments: dict[Sym, int]) -> Optional[int]: @dataclass class Solution: - ctxt: dict[tuple[str, str], int] + ctxt: dict[tuple[Config, str], int] var_assignments: dict[Sym, int] class ConstraintMaker: def __init__(self, type_map: dict[Sym, LoopIR.type]): self.var_subs: dict[Sym, Expression] = {} - self.ctxt: dict[tuple[str, str], Expression] = {} + self.ctxt: dict[tuple[Config, str], Expression] = {} self.extra_constraints: list[Constraint] = [] self.stride_dummies: dict[tuple[Sym, int], Sym] = {} for sym, sym_type in type_map.items(): @@ -310,7 +311,7 @@ def make_constraint_terms( dummy = self.stride_dummies[(expr.name, expr.dim)] return (ConstraintTerm(1, (dummy,)),) elif isinstance(expr, LoopIR.ReadConfig): - if (expr.config.name(), expr.field) not in self.ctxt: + if (expr.config, expr.field) not in self.ctxt: field_type = expr.config.lookup_type(expr.field) var_sub_result = self.make_var_sub( f"{expr.config.name()}_{expr.field}", field_type @@ -318,8 +319,8 @@ def make_constraint_terms( assert ( var_sub_result is not None ), "constraints can only occur on control variables" - self.ctxt[(expr.config.name(), expr.field)] = var_sub_result - return self.ctxt[(expr.config.name(), expr.field)].terms + self.ctxt[(expr.config, expr.field)] = var_sub_result + return self.ctxt[(expr.config, expr.field)].terms else: assert False, f"unsupported expr" @@ -550,10 +551,10 @@ def solve_helper(): if result is not None: var_assignments[sym] = result ctxt = {} - for (config_name, field), sub in self.ctxt.items(): + for (config, field), sub in self.ctxt.items(): result = sub.apply_assignments(assignments) if result is not None: - ctxt[(config_name, field)] = result + ctxt[(config, field)] = result return Solution(ctxt, var_assignments) else: assignments = {} diff --git a/tests/test_chexo.py b/tests/test_chexo.py index 6221fd160..d311fc98c 100644 --- a/tests/test_chexo.py +++ b/tests/test_chexo.py @@ -3,7 +3,6 @@ from exo.rewrite.chexo import ( TypeVisitor, - get_used_config_fields, get_free_variables, collect_path_constraints, collect_arg_size_constraints, @@ -39,22 +38,6 @@ def foo(a: size, b: f32[a]): assert golden == f"Types:\n{types}\nMems:\n{mems}" -def 
test_get_used_config_fields(golden): - @config - class TestConfig: - a: f32 - b: size - c: f32 - - @proc - def foo(a: f32): - TestConfig.c = a - a = TestConfig.a - - used_configs = get_used_config_fields(foo._loopir_proc.body) - assert golden == stringify_dict(used_configs) - - def test_free_variables(golden): @proc def foo(a: size, b: f32[a]): From c9be60e6b449a30529d7458c7d2ab617c5ab59d2 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 18 Mar 2025 00:21:47 -0400 Subject: [PATCH 09/24] remove interpreter --- src/exo/rewrite/chexo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index a28876abb..a0b1dcdc9 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -8,7 +8,6 @@ from dataclasses import dataclass, field from ..core.prelude import Sym, SrcInfo from ..core.memory import DRAM, Memory -from ..backend.LoopIR_interpreter import Interpreter, run_interpreter import numpy as np from .new_eff import SchedulingError from .constraint_solver import ( From e91f9547bfb062d17ff2115340be35f7f4dbdf8a Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 18 Mar 2025 00:36:43 -0400 Subject: [PATCH 10/24] update dependencies --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index dcdb9cf74..9bb57c7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,9 @@ install_requires = build>=1.2.1 z3-solver>=4.13.0.0 yapf>=0.40.2 + scipy>=1.6.2 + hsnf>=0.3.16 + pythonmonkey>=1.1.0 [options.packages.find] where = src From 36b0db08bec91f4664b53054fe9da046354112ee Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Mon, 5 May 2025 18:18:05 -0400 Subject: [PATCH 11/24] coverage finished --- src/exo/backend/LoopIR_transpiler.py | 361 ++++++++-- src/exo/backend/coverage.py | 362 ++++++++++ src/exo/rewrite/LoopIR_scheduling.py | 12 +- src/exo/rewrite/chexo.py | 358 ++++++---- src/exo/rewrite/constraint_solver.py | 636 +++++++++++------- tests/golden/test_transpiler/test_matmul.txt | 22 + .../test_transpiler/test_matmul_coverage.txt | 44 ++ .../test_nested_control_flow_coverage.txt | 30 + .../test_variable_length_array_coverage.txt | 24 + .../test_transpiler/test_window_coverage.txt | 14 + tests/test_apps.py | 3 +- tests/test_chexo.py | 13 - tests/test_transpiler.py | 105 +++ 13 files changed, 1554 insertions(+), 430 deletions(-) create mode 100644 src/exo/backend/coverage.py create mode 100644 tests/golden/test_transpiler/test_matmul.txt create mode 100644 tests/golden/test_transpiler/test_matmul_coverage.txt create mode 100644 tests/golden/test_transpiler/test_nested_control_flow_coverage.txt create mode 100644 tests/golden/test_transpiler/test_variable_length_array_coverage.txt create mode 100644 tests/golden/test_transpiler/test_window_coverage.txt create mode 100644 tests/test_transpiler.py diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index 0c422c040..501b6a39e 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -1,11 +1,21 @@ from functools import reduce +from itertools import chain from string import Template -from typing import Any, Iterable, Union +from typing import Any, Iterable, Optional, Union from ..core.configs import Config from ..core.prelude import Sym from ..core.LoopIR import LoopIR, T +from .coverage import ( + CoverageSkeleton, + CoverageSkeletonNode, + CoverageSkeletonBranch, + IndexedFiller, + MemoryAccess, + MemoryAccessPair, +) +from ..rewrite.constraint_solver import ConstraintMaker from 
dataclasses import dataclass import numpy as np @@ -61,24 +71,56 @@ class Dimension: @dataclass class Tensor: - name: str + name: Sym offset: str dims: tuple[Dimension, ...] + resize_placeholder: Optional[int] ExoValue = Union[Constant, Reference, Tensor] CONTEXT_OBJECT_NAME = "ctxt" +INITIAL_DYNAMIC_SIZE = 16 + + +@dataclass +class CoverageArgs: + cm: ConstraintMaker + + +class CoverageState: + def __init__(self, args: CoverageArgs, cov_placeholder: int): + self.cm: ConstraintMaker = args.cm + self.root: CoverageSkeletonNode = CoverageSkeletonNode(None, None, ()) + self.buffer_writes: dict[Sym, list[MemoryAccess]] = {} + self.buffer_reads: dict[Sym, list[MemoryAccess]] = {} + self.free_vars: list[Sym] = [] + self.cov_placeholder = cov_placeholder + + def make_skeleton(self) -> CoverageSkeleton: + aliasable_accesses: list[MemoryAccessPair] = [] + for sym, write_indices in self.buffer_writes.items(): + read_indices = self.buffer_reads[sym] if sym in self.buffer_reads else [] + for i, index1 in enumerate(write_indices): + for index2 in chain(write_indices[i + 1 :], read_indices): + aliasable_accesses.append(MemoryAccessPair(index1, index2)) + + return CoverageSkeleton( + (self.root,), tuple(aliasable_accesses), frozenset(self.free_vars) + ) class Transpiler: - def __init__(self, proc: LoopIR.proc): + def __init__(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs] = None): self._name_lookup: dict[Sym, ExoValue] = {} self._js_lines: list[str] = [] self._configs: set[tuple[Config, str]] = set() self._buffer_args: list[Sym] = [] - self._transpile_proc(proc) + self._coverage_state: Optional[CoverageState] = None + self._skeleton: Optional[CoverageSkeleton] = None + self.proc = proc # debug + self._transpile_proc(proc, coverage_args) def get_javascript_template(self) -> Template: return Template("\n".join(self._js_lines)) @@ -98,17 +140,25 @@ def get_stride_param_name(self, tensor_name: Sym, dim_idx: int): def get_size_param_name(self, tensor_name: Sym, dim_idx: int): return f"size_{repr(tensor_name)}_{dim_idx}" + def get_coverage_skeleton(self) -> Optional[CoverageSkeleton]: + return self._skeleton + def _assert_at_runtime(self, expr: str): - self._js_lines.append(f"if(!{expr})return [1,{CONTEXT_OBJECT_NAME}];") + self._js_lines.append(f"if(!{expr})return [1,{CONTEXT_OBJECT_NAME},{{}}];") + + def _make_placeholder(self) -> int: + placeholder_index = len(self._js_lines) + self._js_lines.append("") + return placeholder_index - def _transpile_proc(self, proc: LoopIR.proc): + def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs]): arg_values = [] for arg in proc.args: if arg.type.is_numeric(): self._buffer_args.append(arg.name) if arg.type.is_tensor_or_window(): value = Tensor( - repr(arg.name), + arg.name, "0", tuple( Dimension( @@ -117,6 +167,7 @@ def _transpile_proc(self, proc: LoopIR.proc): ) for dim_idx in range(len(arg.type.shape())) ), + None, ) else: value = Reference(repr(arg.name), False) @@ -126,34 +177,58 @@ def _transpile_proc(self, proc: LoopIR.proc): self._js_lines.append( f'(({",".join(repr(arg) for arg in self._buffer_args)})=>{{' ) - ctxt_placeholder = len(self._js_lines) - self._js_lines.append(f"__placeholder__") - self._call_proc(proc, tuple(arg_values)) - self._js_lines.append(f"return [0,{CONTEXT_OBJECT_NAME}];}})") + ctxt_placeholder = self._make_placeholder() + if coverage_args is not None: + self._coverage_state = CoverageState( + coverage_args, self._make_placeholder() + ) + self._call_proc( + proc, + tuple(arg_values), + None if 
self._coverage_state is None else self._coverage_state.root, + ) + coverage_object = "" + if self._coverage_state is not None: + skeleton = self._coverage_state.make_skeleton() + self._skeleton = skeleton + coverage_object = f"{{{','.join(sorted(repr(sym) for sym in self._skeleton.get_coverage_syms()))}}}" + for indexed_filler in sorted(set(skeleton.get_indexed_fillers())): + self._js_lines[indexed_filler.index] += indexed_filler.placefiller + self._js_lines.append(f"return [0,{CONTEXT_OBJECT_NAME},{coverage_object}];}})") configs = ",".join( f"{self.get_config_param_name(config, field)}:${self.get_config_param_name(config, field)}" for config, field in self._configs ) self._js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}" - def _call_proc(self, proc: LoopIR.proc, arg_values: tuple[ExoValue, ...]): + def _call_proc( + self, + proc: LoopIR.proc, + arg_values: tuple[ExoValue, ...], + coverage_node: Optional[CoverageSkeletonNode], + ): for arg, arg_value in zip(proc.args, arg_values): self._name_lookup[arg.name] = arg_value if arg.type.is_tensor_or_window(): assert isinstance(arg_value, Tensor) for arg_dim, arg_dim_expr in zip(arg_value.dims, arg.type.shape()): self._assert_at_runtime( - f"({arg_dim.size}=={self._transpile_expr(arg_dim_expr)})", + f"({arg_dim.size}=={self._transpile_expr(arg_dim_expr, None)})", ) for pred in proc.preds: - self._assert_at_runtime(self._transpile_expr(pred)) + self._assert_at_runtime(self._transpile_expr(pred, None)) for stmt in proc.body: - self._transpile_stmt(stmt) + self._transpile_stmt(stmt, coverage_node) - def _get_index_expr(self, buf: Tensor, idxs: Iterable[LoopIR.expr]): - idx_exprs = tuple(self._transpile_expr(idx) for idx in idxs) + def _get_index_expr( + self, + buf: Tensor, + idxs: Iterable[LoopIR.expr], + coverage_node: Optional[CoverageSkeletonNode], + ): + idx_exprs = tuple(self._transpile_expr(idx, coverage_node) for idx in idxs) for idx_expr, dim in zip(idx_exprs, buf.dims): self._assert_at_runtime(f"({idx_expr}<{dim.size}&&{idx_expr}>=0)") relative_idx = reduce( @@ -165,16 +240,83 @@ def _get_index_expr(self, buf: Tensor, idxs: Iterable[LoopIR.expr]): ) return f"{relative_idx}+{buf.offset}" - def _transpile_stmt(self, stmt: LoopIR.stmt): + def _make_scalar_access_fillers(self, access_sym: Sym) -> tuple[IndexedFiller, ...]: + assert self._coverage_state is not None + mark_placeholder = self._make_placeholder() + mark_stmt = f"{repr(access_sym)}=true;" + decl_stmt = f"let {repr(access_sym)}=false;" + return ( + IndexedFiller(mark_placeholder, mark_stmt), + IndexedFiller(self._coverage_state.cov_placeholder, decl_stmt), + ) + + def _make_tensor_access_fillers( + self, access_sym: Sym, buffer: Tensor, idx: Iterable[LoopIR.expr] + ) -> tuple[IndexedFiller, ...]: + assert self._coverage_state is not None + mark_stmt = f"{repr(access_sym)}[{self._get_index_expr(buffer, idx, None)}]=1;" + mark_placeholder = self._make_placeholder() + base_buffer = self._name_lookup[buffer.name] + assert isinstance(base_buffer, Tensor) + base_dims = base_buffer.dims + base_size = reduce( + lambda dim1, dim2: f"Math.imul({dim1},{dim2})", + (dim.size for dim in base_dims), + ) + if buffer.resize_placeholder is None: + decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer({base_size});" + return ( + IndexedFiller(mark_placeholder, mark_stmt), + IndexedFiller(self._coverage_state.cov_placeholder, decl_stmt), + ) + else: + temp_sym = Sym("temp") + decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer(1,{{maxByteLength:{INITIAL_DYNAMIC_SIZE}}});" + 
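# Aside (assumed behavior, sketched for clarity): for a dynamically sized
# Alloc the coverage bitmap starts at one byte and is regrown at the
# allocation site it shadows — the resize_stmt below doubles maxByteLength
# until the shadowed buffer fits, then resizes. The growth rule in plain
# Python (hypothetical helper, not part of the patch):
def grow_to_fit(buf: bytearray, needed: int) -> bytearray:
    while needed > len(buf):
        buf = buf + bytes(max(len(buf), 1))  # at least double the capacity
    return buf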
resize_stmt = f"while({base_size}>{repr(access_sym)}.maxByteLength){{let {repr(temp_sym)}=new ArrayBuffer({repr(access_sym)}.byteLength,{{maxByteLength:2*{repr(access_sym)}.maxByteLength}});for(let i=0;i<{repr(access_sym)}.byteLength;i++){repr(temp_sym)}[i]={repr(access_sym)}[i];{repr(access_sym)}={repr(temp_sym)}}};{repr(access_sym)}.resize(Math.max({base_size},{repr(access_sym)}.byteLength));" + return ( + IndexedFiller(mark_placeholder, mark_stmt), + IndexedFiller(self._coverage_state.cov_placeholder, decl_stmt), + IndexedFiller(buffer.resize_placeholder, resize_stmt), + ) + + def _transpile_stmt( + self, stmt: LoopIR.stmt, coverage_node: Optional[CoverageSkeletonNode] + ): if isinstance(stmt, (LoopIR.Assign, LoopIR.Reduce)): lhs_buffer = self._name_lookup[stmt.name] - rhs = self._transpile_expr(stmt.rhs) + + if self._coverage_state is not None and coverage_node is not None: + write_sym = Sym("write") + if stmt.name not in self._coverage_state.buffer_writes: + self._coverage_state.buffer_writes[ + lhs_buffer.name if isinstance(lhs_buffer, Tensor) else stmt.name + ] = [] + self._coverage_state.buffer_writes[ + lhs_buffer.name if isinstance(lhs_buffer, Tensor) else stmt.name + ].append( + MemoryAccess( + write_sym, + coverage_node, + tuple( + self._coverage_state.cm.make_expression(idx) + for idx in stmt.idx + ), + ( + self._make_tensor_access_fillers( + write_sym, lhs_buffer, stmt.idx + ) + if isinstance(lhs_buffer, Tensor) + else self._make_scalar_access_fillers(write_sym) + ), + ) + ) + rhs = self._transpile_expr(stmt.rhs, coverage_node) if isinstance(lhs_buffer, Reference): lhs = ( lhs_buffer.name if lhs_buffer.is_config else f"{lhs_buffer.name}[0]" ) elif isinstance(lhs_buffer, Tensor): - lhs = f"{lhs_buffer.name}[{self._get_index_expr(lhs_buffer, stmt.idx)}]" + lhs = f"{repr(lhs_buffer.name)}[{self._get_index_expr(lhs_buffer, stmt.idx, coverage_node)}]" else: assert False if isinstance(stmt, LoopIR.Assign): @@ -183,32 +325,121 @@ def _transpile_stmt(self, stmt: LoopIR.stmt): self._js_lines.append(f"{lhs}+={rhs};") elif isinstance(stmt, LoopIR.WriteConfig): config_name = self.get_config_param_name(stmt.config, stmt.field) - rhs = self._transpile_expr(stmt.rhs) + rhs = self._transpile_expr(stmt.rhs, coverage_node) self._js_lines.append(f'{CONTEXT_OBJECT_NAME}["{config_name}"]={rhs};') self._configs.add((stmt.config, stmt.field)) elif isinstance(stmt, LoopIR.Pass): pass elif isinstance(stmt, LoopIR.If): - cond = self._transpile_expr( - stmt.cond, - ) + cond = self._transpile_expr(stmt.cond, coverage_node) self._js_lines.append(f"if({cond}){{") - for body_stmt in stmt.body: - self._transpile_stmt(body_stmt) - self._js_lines.append("}else{") - for else_stmt in stmt.orelse: - self._transpile_stmt(else_stmt) + + if self._coverage_state is not None and coverage_node is not None: + true_sym = Sym("true_case") + false_sym = Sym("false_case") + true_placeholder = self._make_placeholder() + cond_constraint = self._coverage_state.cm.make_constraint(stmt.cond) + true_node = CoverageSkeletonNode( + true_sym, + (coverage_node, cond_constraint), + ( + IndexedFiller( + self._coverage_state.cov_placeholder, + f"let {repr(true_sym)}=false;", + ), + IndexedFiller(true_placeholder, f"{repr(true_sym)}=true;"), + ), + ) + for body_stmt in stmt.body: + self._transpile_stmt(body_stmt, true_node) + self._js_lines.append("}else{") + false_placeholder = self._make_placeholder() + false_node = CoverageSkeletonNode( + false_sym, + (coverage_node, cond_constraint.invert()), + ( + IndexedFiller( + 
self._coverage_state.cov_placeholder, + f"let {repr(false_sym)}=false;", + ), + IndexedFiller(false_placeholder, f"{repr(false_sym)}=true;"), + ), + ) + for else_stmt in stmt.orelse: + self._transpile_stmt(else_stmt, false_node) + new_branch = CoverageSkeletonBranch(true_node, false_node) + coverage_node.branches.append(new_branch) + else: + for body_stmt in stmt.body: + self._transpile_stmt(body_stmt, None) + self._js_lines.append("}else{") + for else_stmt in stmt.orelse: + self._transpile_stmt(else_stmt, None) self._js_lines.append("}") elif isinstance(stmt, LoopIR.For): iter_name = repr(stmt.iter) - iter_lo = self._transpile_expr(stmt.lo) - iter_hi = self._transpile_expr(stmt.hi) + iter_lo = self._transpile_expr(stmt.lo, coverage_node) + iter_hi = self._transpile_expr(stmt.hi, coverage_node) self._name_lookup[stmt.iter] = Constant(iter_name) + + body_child, skip_child = None, None + if self._coverage_state is not None and coverage_node is not None: + body_sym = Sym("body") + skip_sym = Sym("skip") + loop_placeholder = self._make_placeholder() + body_constraint = ( + self._coverage_state.cm.make_constraint_from_inequality( + stmt.lo, stmt.iter, "<=" + ) + .lift_to_disjoint_constraint() + .intersect( + self._coverage_state.cm.make_constraint_from_inequality( + stmt.iter, stmt.hi, "<" + ).lift_to_disjoint_constraint() + ) + ) + skip_constraint = ( + self._coverage_state.cm.make_constraint_from_inequality( + stmt.lo, stmt.hi, ">=" + ).lift_to_disjoint_constraint() + ) + body_child = CoverageSkeletonNode( + body_sym, + (coverage_node, body_constraint), + ( + IndexedFiller( + self._coverage_state.cov_placeholder, + f"let {repr(body_sym)}=false;", + ), + IndexedFiller( + loop_placeholder, + f"{repr(body_sym)}||=({iter_lo}<{iter_hi});", + ), + ), + ) + skip_child = CoverageSkeletonNode( + skip_sym, + (coverage_node, skip_constraint), + ( + IndexedFiller( + self._coverage_state.cov_placeholder, + f"let {repr(skip_sym)}=false;", + ), + IndexedFiller( + loop_placeholder, + f"{repr(skip_sym)}||=({iter_lo}>={iter_hi});", + ), + ), + ) + self._coverage_state.free_vars.append(stmt.iter) + new_loop = CoverageSkeletonBranch(body_child, skip_child) + coverage_node.branches.append(new_loop) + self._js_lines.append( f"for(let {iter_name}={iter_lo};{iter_name}<{iter_hi};{iter_name}++){{" ) for body_stmt in stmt.body: - self._transpile_stmt(body_stmt) + self._transpile_stmt(body_stmt, body_child) self._js_lines.append("}") elif isinstance(stmt, LoopIR.Alloc): assert stmt.type.is_numeric() @@ -218,7 +449,8 @@ def _transpile_stmt(self, stmt: LoopIR.stmt): stmt.type.basetype() ).javascript_array_type dim_exprs = tuple( - self._transpile_expr(dim) for dim in stmt.type.shape() + self._transpile_expr(dim, coverage_node) + for dim in stmt.type.shape() ) for dim_expr in dim_exprs: self._assert_at_runtime(f"({dim_expr}>=0)") @@ -228,6 +460,8 @@ def _transpile_stmt(self, stmt: LoopIR.stmt): self._js_lines.append( f"let {tensor_name}=new {buffer_type}({buffer_size});" ) + resize_placeholder = len(self._js_lines) + self._js_lines.append("") dimensions: list[Dimension] = [] for dim_idx, dim_expr in enumerate(dim_exprs): self._assert_at_runtime(f"({dim_expr}>=0)") @@ -238,7 +472,7 @@ def _transpile_stmt(self, stmt: LoopIR.stmt): ) dimensions.append(Dimension(dim_expr, stride_expr)) self._name_lookup[stmt.name] = Tensor( - tensor_name, "0", tuple(dimensions) + stmt.name, "0", tuple(dimensions), resize_placeholder ) else: ref_name = repr(stmt.name) @@ -252,19 +486,24 @@ def _transpile_stmt(self, stmt: LoopIR.stmt): stmt.f, 
tuple( ( - self._transpile_buffer_arg(arg_expr) + self._transpile_buffer_arg(arg_expr, coverage_node) if arg_expr.type.is_numeric() - else Constant(self._transpile_expr(arg_expr)) + else Constant(self._transpile_expr(arg_expr, coverage_node)) ) for arg_expr in stmt.args ), + coverage_node, ) elif isinstance(stmt, LoopIR.WindowStmt): - self._name_lookup[stmt.name] = self._transpile_buffer_arg(stmt.rhs) + self._name_lookup[stmt.name] = self._transpile_buffer_arg( + stmt.rhs, coverage_node + ) else: assert False, "unsupported stmt" - def _transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: + def _transpile_buffer_arg( + self, expr: LoopIR.expr, coverage_node: Optional[CoverageSkeletonNode] + ) -> Union[Tensor, Reference]: if isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0 buf = self._name_lookup[expr.name] @@ -277,8 +516,8 @@ def _transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: window_dims = [] for idx, dim in zip(expr.idx, base.dims): if isinstance(idx, LoopIR.Interval): - lo_expr = self._transpile_expr(idx.lo) - hi_expr = self._transpile_expr(idx.hi) + lo_expr = self._transpile_expr(idx.lo, coverage_node) + hi_expr = self._transpile_expr(idx.hi, coverage_node) self._assert_at_runtime( f"(0<={lo_expr}&&{lo_expr}<={hi_expr}&&{hi_expr}<={dim.size})" ) @@ -286,12 +525,14 @@ def _transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: size_expr = f"({hi_expr}-{lo_expr})" window_dims.append(Dimension(size_expr, dim.stride)) elif isinstance(idx, LoopIR.Point): - pt_expr = self._transpile_expr(idx.pt) + pt_expr = self._transpile_expr(idx.pt, coverage_node) self._assert_at_runtime(f"(0<={pt_expr}&&{pt_expr}<{dim.size})") offset_expr = f"({offset_expr}+Math.imul({pt_expr},{dim.stride}))" else: assert False, "not a window index" - return Tensor(base.name, offset_expr, tuple(window_dims)) + return Tensor( + base.name, offset_expr, tuple(window_dims), base.resize_placeholder + ) elif isinstance(expr, LoopIR.ReadConfig): self._configs.add((expr.config, expr.field)) return Reference( @@ -302,13 +543,35 @@ def _transpile_buffer_arg(self, expr: LoopIR.expr) -> Union[Tensor, Reference]: assert False, "unsupported buffer expression" def _transpile_expr( - self, - expr: LoopIR.expr, + self, expr: LoopIR.expr, coverage_node: Optional[CoverageSkeletonNode] ) -> str: if isinstance(expr, LoopIR.Read): buf = self._name_lookup[expr.name] + if self._coverage_state is not None and coverage_node is not None: + read_sym = Sym("read") + if expr.name not in self._coverage_state.buffer_reads: + self._coverage_state.buffer_reads[ + buf.name if isinstance(buf, Tensor) else expr.name + ] = [] + self._coverage_state.buffer_reads[ + buf.name if isinstance(buf, Tensor) else expr.name + ].append( + MemoryAccess( + read_sym, + coverage_node, + tuple( + self._coverage_state.cm.make_expression(idx) + for idx in expr.idx + ), + ( + self._make_tensor_access_fillers(read_sym, buf, expr.idx) + if isinstance(buf, Tensor) + else self._make_scalar_access_fillers(read_sym) + ), + ) + ) if isinstance(buf, Tensor): - return f"{buf.name}[{self._get_index_expr(buf, expr.idx)}]" + return f"{repr(buf.name)}[{self._get_index_expr(buf, expr.idx, coverage_node)}]" elif isinstance(buf, Reference): return buf.name if buf.is_config else f"{buf.name}[0]" else: @@ -321,10 +584,10 @@ def _transpile_expr( else: assert False, "unexpected const type" elif isinstance(expr, LoopIR.USub): - return f"(-{self._transpile_expr(expr.arg)})" + return f"(-{self._transpile_expr(expr.arg, 
coverage_node)})" elif isinstance(expr, LoopIR.BinOp): - lhs = self._transpile_expr(expr.lhs) - rhs = self._transpile_expr(expr.rhs) + lhs = self._transpile_expr(expr.lhs, coverage_node) + rhs = self._transpile_expr(expr.rhs, coverage_node) is_int = ( isinstance(expr.type, (T.INT8, T.UINT8, T.UINT16, T.INT32)) or not expr.type.is_numeric() @@ -351,7 +614,7 @@ def _transpile_expr( return val elif isinstance(expr, LoopIR.Extern): return expr.f.transpile( - tuple(self._transpile_expr(arg) for arg in expr.args) + tuple(self._transpile_expr(arg, coverage_node) for arg in expr.args) ) elif isinstance(expr, LoopIR.WindowExpr): assert False, "unexpected window expr" diff --git a/src/exo/backend/coverage.py b/src/exo/backend/coverage.py new file mode 100644 index 000000000..06071aacb --- /dev/null +++ b/src/exo/backend/coverage.py @@ -0,0 +1,362 @@ +from dataclasses import dataclass, field +from itertools import groupby +from typing import Generator, Iterable, Optional, Union +import numpy as np + +from ..rewrite.constraint_solver import ( + Constraint, + ConstraintMaker, + DisjointConstraint, + TRUE_CONSTRAINT, + Expression, + Solution, +) +from ..core.prelude import Sym + + +@dataclass +class CoverageProgress: + covered_cases: int + total_cases: int + + def merge(self, other: "CoverageProgress") -> "CoverageProgress": + return CoverageProgress( + self.covered_cases + other.covered_cases, + self.total_cases + other.total_cases, + ) + + +@dataclass +class CoverageSolverState: + current_constraint: DisjointConstraint + is_base_constraint: bool + current_solution: Solution + cm: ConstraintMaker + free_vars: frozenset[Sym] + bound: int + search_limit: int + + def update_solution( + self, new_constraint: DisjointConstraint, new_solution: Solution + ): + return CoverageSolverState( + new_constraint, + False, + new_solution, + self.cm, + self.free_vars, + self.bound, + self.search_limit, + ) + + +@dataclass(order=True, frozen=True) +class IndexedFiller: + index: int + placefiller: str + + +@dataclass +class CoverageSkeletonBranch: + true_child: "CoverageSkeletonNode" + false_child: "CoverageSkeletonNode" + + def write_coverage_declarations(self, decls: dict[Sym, str]): + self.true_child.write_coverage_declarations(decls) + self.false_child.write_coverage_declarations(decls) + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + yield from self.true_child.get_indexed_fillers() + yield from self.false_child.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return ( + self.true_child.get_coverage_syms() | self.false_child.get_coverage_syms() + ) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + self.true_child.update_coverage(coverage_result) + self.false_child.update_coverage(coverage_result) + + def get_coverage_progress(self) -> CoverageProgress: + return self.true_child.get_coverage_progress().merge( + self.false_child.get_coverage_progress() + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + uncovered_path = None + if self.true_child.visited and not self.false_child.visited: + uncovered_path = False + elif self.false_child.visited and not self.true_child.visited: + uncovered_path = True + + if uncovered_path is not None: + path_constraint = ( + self.true_child.get_complete_constraint() + if uncovered_path + else self.false_child.get_complete_constraint() + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = 
state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + if uncovered_path: + self.true_child.visited = True + else: + self.false_child.visited = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + elif self.true_child.visited and self.false_child.visited: + return self.false_child.solve_coverage( + self.true_child.solve_coverage(state) + ) + return state + + +@dataclass +class CoverageSkeletonNode: + coverage_sym: Optional[Sym] + parent_edge: Optional[tuple["CoverageSkeletonNode", DisjointConstraint]] + indexed_fillers: tuple[IndexedFiller, ...] + branches: list[CoverageSkeletonBranch] = field(default_factory=lambda: []) + visited: bool = False # mutable + + def write_coverage_declarations(self, decls: dict[Sym, str]): + if self.coverage_sym is not None: + decls[self.coverage_sym] = "false" + for branch in self.branches: + branch.write_coverage_declarations(decls) + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + for branch in self.branches: + yield from branch.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset( + (self.coverage_sym,) if self.coverage_sym is not None else () + ).union(*tuple(branch.get_coverage_syms() for branch in self.branches)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + if self.coverage_sym is None: + self.visited = True + else: + covered = coverage_result[repr(self.coverage_sym)] + assert isinstance(covered, bool) + self.visited |= covered + for branch in self.branches: + branch.update_coverage(coverage_result) + + def get_complete_constraint(self) -> DisjointConstraint: + current_edge = self.parent_edge + result = TRUE_CONSTRAINT + while current_edge is not None: + result = result.intersect(current_edge[1]) + current_edge = current_edge[0].parent_edge + return result + + def get_coverage_progress(self) -> CoverageProgress: + result = CoverageProgress(1 if self.visited else 0, 1) + for branch in self.branches: + result = result.merge(branch.get_coverage_progress()) + return result + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + current_state = state + for branch in self.branches: + current_state = branch.solve_coverage(current_state) + return current_state + + +@dataclass +class MemoryAccess: + coverage_sym: Sym + node: CoverageSkeletonNode + index: tuple[Expression, ...] + indexed_fillers: tuple[IndexedFiller, ...] 
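+    # coverage_sym names the JS flag/array that records whether (and at which
+    # indices) this access actually executed; node supplies the path
+    # constraint guarding the access; index holds the symbolic index
+    # expressions used to build aliasing constraints between access pairs;
+    # indexed_fillers are the JS snippets spliced into the generated template
+    # to do the recording.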
+ + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def make_renamed_constraint_and_indices( + self, state: CoverageSolverState + ) -> tuple[DisjointConstraint, tuple[Expression, ...]]: + path_constraint = self.node.get_complete_constraint() + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms().union( + *(index_expr.collect_syms() for index_expr in self.index) + ), + state.free_vars, + ) + return ( + path_constraint.rename_syms(sym_renaming), + tuple(index_expr.rename_syms(sym_renaming) for index_expr in self.index), + ) + + +@dataclass +class MemoryAccessPair: + access1: MemoryAccess + access2: MemoryAccess + visited_aliasing: bool = False # mutable + visited_nonaliasing: bool = False # mutable + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + yield from self.access1.get_indexed_fillers() + yield from self.access2.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.access1.coverage_sym, self.access2.coverage_sym)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + access1_view = coverage_result[repr(self.access1.coverage_sym)] + access2_view = coverage_result[repr(self.access2.coverage_sym)] + if isinstance(access1_view, memoryview): + assert isinstance(access2_view, memoryview) + access1_arr = np.asarray(access1_view) + access2_arr = np.asarray(access2_view) + self.visited_aliasing |= np.any(access1_arr & access2_arr) + self.visited_nonaliasing |= np.any(access1_arr & ~access2_arr) and np.any( + ~access1_arr & access2_arr + ) + else: + assert isinstance(access2_view, bool) + aliased = access1_view and access2_view + self.visited_aliasing |= aliased # nonaliasing not possible without tensor + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + (1 if self.visited_aliasing else 0) + + (1 if self.visited_nonaliasing else 0), + 2, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + uncovered_path = None + if self.visited_aliasing and not self.visited_nonaliasing: + uncovered_path = False + elif self.visited_nonaliasing and not self.visited_aliasing: + uncovered_path = True + + if uncovered_path is not None: + ( + access1_path_constraint, + access1_indices, + ) = self.access1.make_renamed_constraint_and_indices(state) + ( + access2_path_constraint, + access2_indices, + ) = self.access2.make_renamed_constraint_and_indices(state) + path_constraints = access1_path_constraint.intersect( + access2_path_constraint + ) + alias_constraint = TRUE_CONSTRAINT + for index1, index2 in zip(access1_indices, access2_indices): + alias_constraint = alias_constraint.intersect( + Constraint( + Expression( + tuple(term.negate() for term in index1.terms) + index2.terms + ), + False, + ).lift_to_disjoint_constraint() + ) + if not uncovered_path: + alias_constraint = alias_constraint.invert() + new_constraint = state.current_constraint.intersect( + path_constraints + ).intersect(alias_constraint) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + if uncovered_path: + self.visited_aliasing = True + else: + self.visited_nonaliasing = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + +@dataclass +class 
CoverageSkeleton: + roots: tuple[CoverageSkeletonNode, ...] + aliasable_accesses: tuple[MemoryAccessPair, ...] + free_vars: frozenset[Sym] + + def merge(self, other: "CoverageSkeleton") -> "CoverageSkeleton": + return CoverageSkeleton( + self.roots + other.roots, + self.aliasable_accesses + other.aliasable_accesses, + self.free_vars | other.free_vars, + ) + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for root in self.roots: + yield from root.get_indexed_fillers() + for aliasable_access in self.aliasable_accesses: + yield from aliasable_access.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset().union( + *tuple(root_node.get_coverage_syms() for root_node in self.roots), + *tuple( + aliasable_access.get_coverage_syms() + for aliasable_access in self.aliasable_accesses + ), + ) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + for root_node in self.roots: + root_node.update_coverage(coverage_result) + for aliasable_access in self.aliasable_accesses: + aliasable_access.update_coverage(coverage_result) + + def get_coverage_progress(self) -> CoverageProgress: + result = CoverageProgress(0, 0) + for root_node in self.roots: + result = root_node.get_coverage_progress() + for aliasable_access in self.aliasable_accesses: + result = result.merge(aliasable_access.get_coverage_progress()) + return result + + def solve_constraint_with_coverage( + self, + cm: ConstraintMaker, + base_constraint: DisjointConstraint, + *, + bound: int, + search_limit: int, + ) -> Optional[Solution]: + base_solution = cm.solve_constraint( + base_constraint, bound=bound, search_limit=search_limit + ) + if base_solution is None: + return None + state = CoverageSolverState( + base_constraint, + True, + base_solution, + cm, + self.free_vars, + bound, + search_limit, + ) + for aliasable_access in self.aliasable_accesses: + state = aliasable_access.solve_coverage(state) + for root_node in self.roots: + state = root_node.solve_coverage(state) + return state.current_solution diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index 616b6bb63..86dad4381 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -1,5 +1,6 @@ import re from collections import ChainMap +import traceback from typing import Callable, List, Literal, Tuple, Optional from ..core.LoopIR import ( @@ -218,7 +219,7 @@ def _replace_reads(ir, fwd, c, sym, repl, only_replace_attrs=True): c = fwd(c) todos = [] for rd in match_pattern(c, f"{repr(sym)}[_]", use_sym_id=True): - # Need [_] to pattern match against window expressions + # Need [_] to pattern match against window expressiontatic if c_repl := repl(rd): todos.append((rd, c_repl)) @@ -375,16 +376,21 @@ def do_check( ): if mode == "both": e_static, e_dynamic = None, None + trb_static, trb_dynamic = None, None try: static_check() except Exception as e: e_static = e + trb_static = traceback.format_exc() try: dynamic_check() except Exception as e: e_dynamic = e + trb_dynamic = traceback.format_exc() if (e_static is None) != (e_dynamic is None): - assert False, "fuzzer should match static analysis" + assert ( + False + ), f"fuzzer should match static analysis\ntrb_static: {trb_static}\n\ntrb_dynamic: {trb_dynamic}" elif e_static is not None: raise e_static elif mode == "static": @@ -404,7 +410,7 @@ def DoReorderStmt(f_cursor, s_cursor): do_check( lambda: Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node), lambda: 
fuzz_reorder_stmts(f_cursor, s_cursor), - "both", + "dynamic", ) ir, fwd = s_cursor._move(f_cursor.before()) return ir, fwd diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index a0b1dcdc9..c630f9165 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -1,6 +1,10 @@ -from typing import Optional, Union +from itertools import chain +from typing import Callable, Literal, Optional, Union -from ..backend.LoopIR_transpiler import Transpiler +from ..core.internal_cursors import Cursor, Block, Node + +from ..backend.LoopIR_transpiler import CoverageArgs, Transpiler +from ..backend.coverage import CoverageSkeleton from ..core.configs import Config @@ -17,6 +21,8 @@ ConstraintMaker, ConstraintTerm, DisjointConstraint, + Expression, + Solution, ) from pythonmonkey import eval as js_eval @@ -80,6 +86,63 @@ def visit(self, node): self.visit_generic(node) +class LoopIRModifier: + def visit(self, node): + return self.visit_generic(node) + + def visit_generic(self, node): + if ( + isinstance(node, LoopIR.proc) + or isinstance(node, LoopIR.instr) + or isinstance(node, LoopIR.fnarg) + or isinstance(node, LoopIR.stmt) + or isinstance(node, LoopIR.loop_mode) + or isinstance(node, LoopIR.expr) + or isinstance(node, LoopIR.w_access) + or isinstance(node, LoopIR.type) + ): + updates = {} + for field_name in dir(node): + if not field_name.startswith("_"): + field = getattr(node, field_name) + if isinstance(field, list): + new_field = field + for child_idx, child in enumerate(field): + new_child = self.visit(child) + if new_child != child: + if new_field == field: + new_field = field.copy() + new_field[child_idx] = new_child + else: + new_field = self.visit(field) + if new_field != field: + updates[field_name] = new_field + return node.update(**updates) if len(updates) != 0 else node + + +@dataclass +class ReadWriteSyms: + reduced_syms: set[Sym] + assigned_syms: set[Sym] + written_configs: set[tuple[Config, str]] + read_syms: set[Sym] + + +# @dataclass +# class LoopFlattener(LoopIRModifier): +# universal_var_types: dict[Sym, LoopIR.type] = field(default_factory=lambda: {}) +# loop_syms: Optional[ReadWriteSyms] = None + +# def visit(self, node): +# if isinstance(node, LoopIR.For): +# old_loop_syms = self.loop_syms +# new_node = self.visit_generic(node) +# self.loop_syms = old_loop_syms +# elif isinstance(node, LoopIR.Assign): +# elif isinstance(node, LoopIR.Reduce): +# elif isinstance(node, LoopIR.WriteConfig): + + @dataclass class Dimension: size: int @@ -92,12 +155,16 @@ class Tensor: dims: tuple[Dimension, ...] 
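+    # Concrete tensor value generated for a fuzz test case; each Dimension in
+    # `dims` records the integer extent of one axis of the generated buffer.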
-def get_free_variables(type_map, mem_map, fragment): +def get_free_variables(type_map, mem_map, fragment: Union[Block, Node]): fragment_type_visitor = TypeVisitor() fragment_var_visitor = UsedVariableVisitor() - for stmt in fragment: - fragment_type_visitor.visit(stmt) - fragment_var_visitor.visit(stmt) + if isinstance(fragment, Block): + for fragment_node in fragment.resolve_all(): + fragment_type_visitor.visit(fragment_node) + fragment_var_visitor.visit(fragment_node) + else: + fragment_type_visitor.visit(fragment._node) + fragment_var_visitor.visit(fragment._node) for var in fragment_var_visitor.used_vars - fragment_type_visitor.type_map.keys(): fragment_var_visitor.visit(type_map[var]) return { @@ -143,64 +210,42 @@ def eval_tensor_dimension( CONTROL_VAL_BOUND = 128 -MIN_BUFFER_SIZE_BOUND = 16**1 -MAX_BUFFER_SIZE_BOUND = 16**6 SEARCH_LIMIT = 10 INT_BOUND = 128 FLOAT_BOUND = 32 -def collect_path_constraints(cursor, cm: ConstraintMaker) -> DisjointConstraint: +def collect_path_constraints( + cursor: Union[Block, Node], cm: ConstraintMaker +) -> DisjointConstraint: cur = cursor result = TRUE_CONSTRAINT last_attr = None while cur.depth() != 0: - if isinstance(cur._node, LoopIR.For): - result = result.intersect( - cm.make_constraint_from_inequality( - cur._node.iter, cur._node.lo, ">=" - ).lift_to_disjoint_constraint() - ) - result = result.intersect( - cm.make_constraint_from_inequality( - cur._node.iter, cur._node.hi, "<" - ).lift_to_disjoint_constraint() - ) - elif isinstance(cur._node, LoopIR.If): - constraint = cm.make_constraint(cur._node.cond) - if isinstance(last_attr, tuple) and last_attr[0] == "orelse": - result = result.intersect(constraint.invert()) - else: - result = result.intersect(constraint) - last_attr = cur._path[-1] + if isinstance(cur, Node): + last_attr = cur._path[-1] + if isinstance(cur._node, LoopIR.For): + result = result.intersect( + cm.make_constraint_from_inequality( + cur._node.iter, cur._node.lo, ">=" + ).lift_to_disjoint_constraint() + ) + result = result.intersect( + cm.make_constraint_from_inequality( + cur._node.iter, cur._node.hi, "<" + ).lift_to_disjoint_constraint() + ) + elif isinstance(cur._node, LoopIR.If): + constraint = cm.make_constraint(cur._node.cond) + if isinstance(last_attr, tuple) and last_attr[0] == "orelse": + result = result.intersect(constraint.invert()) + else: + result = result.intersect(constraint) cur = cur.parent() return result -def collect_arg_size_constraints( - args: list[LoopIR.fnarg], cm: ConstraintMaker, buffer_size_bound: int -) -> DisjointConstraint: - constraint = TRUE_CONSTRAINT - for arg in args: - if arg.type.is_tensor_or_window(): - dim_terms: tuple[ConstraintTerm, ...] 
= (ConstraintTerm(1, ()),) - for dim_expr in arg.type.shape(): - dim_terms = tuple( - dim_term.multiply(rhs_term) - for dim_term in dim_terms - for rhs_term in cm.make_constraint_terms(dim_expr) - ) - constraint = constraint.intersect( - Constraint( - tuple(term.negate() for term in dim_terms) - + (ConstraintTerm(buffer_size_bound, ()),), - True, - ).lift_to_disjoint_constraint() - ) - return constraint - - @dataclass class TestCase: arg_values: dict[Sym, Union[int, bool, float, Tensor]] @@ -255,15 +300,16 @@ def generate_numeric_value(var_type: LoopIR.type, shape: tuple[int, ...]) -> Ten def generate_test_case( - args: list[LoopIR.fnarg], + arg_types: dict[Sym, LoopIR.type], config_fields: frozenset[tuple[Config, str]], constraint: DisjointConstraint, + coverage_skeleton: CoverageSkeleton, cm: ConstraintMaker, ) -> Optional[TestCase]: ctxt = {} arg_values = {} - solution = cm.solve_constraint( - constraint, bound=CONTROL_VAL_BOUND, search_limit=SEARCH_LIMIT + solution = coverage_skeleton.solve_constraint_with_coverage( + cm, constraint, bound=INT_BOUND, search_limit=SEARCH_LIMIT ) if solution is None: return None @@ -278,27 +324,27 @@ def generate_test_case( val = generate_control_value(field_type) ctxt[(config, field)] = val - for arg in args: - if not arg.type.is_numeric(): - if arg.name in solution.var_assignments: - if isinstance(arg.type, T.Bool): - val = solution.var_assignments[arg.name] != 0 + for arg_name, arg_type in arg_types.items(): + if not arg_type.is_numeric(): + if arg_name in solution.var_assignments: + if isinstance(arg_type, T.Bool): + val = solution.var_assignments[arg_name] != 0 else: - val = solution.var_assignments[arg.name] + val = solution.var_assignments[arg_name] else: - val = generate_control_value(arg.type) - arg_values[arg.name] = val + val = generate_control_value(arg_type) + arg_values[arg_name] = val - for arg in args: - if arg.type.is_numeric(): - if arg.type.is_real_scalar(): + for arg_name, arg_type in arg_types.items(): + if arg_type.is_numeric(): + if arg_type.is_real_scalar(): shape = (1,) else: shape = tuple( eval_tensor_dimension(dim_expr, arg_values) - for dim_expr in arg.type.shape() + for dim_expr in arg_type.shape() ) - arg_values[arg.name] = generate_numeric_value(arg.type.basetype(), shape) + arg_values[arg_name] = generate_numeric_value(arg_type.basetype(), shape) return TestCase(arg_values, ctxt) @@ -307,9 +353,13 @@ def generate_test_case( class TestResult: buffer_values: dict[Sym, np.ndarray] ctxt_object: dict[str, Union[int, float]] + coverage_result: Optional[dict[str, Union[bool, memoryview, float]]] -def run_test_case(test_case: TestCase, transpiled_proc: Transpiler) -> TestResult: +def run_test_case( + test_case: TestCase, + transpiled_proc: Transpiler, +) -> Union[TestResult, Literal["failed"]]: subs = {} for arg_name, arg_value in test_case.arg_values.items(): if isinstance(arg_value, Tensor): @@ -349,11 +399,18 @@ def run_test_case(test_case: TestCase, transpiled_proc: Transpiler) -> TestResul ) javascript = transpiled_proc.get_javascript_template().substitute(subs) try: - [result, ctxt_object] = js_eval(javascript)(*buffer_args) + eval_info = js_eval(javascript)(*buffer_args) except Exception as e: - print(e) - - assert result == 0 + raise Exception( + f"javascript:\n{javascript}\nproc:\n{transpiled_proc.proc}" + ) from e + if transpiled_proc.get_coverage_skeleton() is None: + [result, ctxt_object] = eval_info + coverage_result = None + else: + [result, ctxt_object, coverage_result] = eval_info + if result != 0: + return 
"failed" return TestResult( { buffer_name: buffer_value @@ -362,83 +419,121 @@ def run_test_case(test_case: TestCase, transpiled_proc: Transpiler) -> TestResul ) }, ctxt_object, + coverage_result, ) -TEST_CASE_BOUND = 15 +@dataclass +class TestSpec: + proc: LoopIR.proc + constraint: DisjointConstraint + arg_types: dict[Sym, LoopIR.type] -def fuzz_reorder_stmts(s1, s2): - proc = s1.get_root() - proc_type_visitor = TypeVisitor() - proc_type_visitor.visit(proc) - - cm = ConstraintMaker(proc_type_visitor.type_map) - constraint = TRUE_CONSTRAINT - for pred in proc.preds: - constraint = constraint.intersect(cm.make_constraint(pred)) - constraint = constraint.intersect(collect_path_constraints(s1, cm)) - args = [ - LoopIR.fnarg( - name=var, - type=arg_type, - mem=DRAM if arg_mem is None else arg_mem, - srcinfo=SrcInfo("", 0), - ) - for var, (arg_type, arg_mem) in get_free_variables( - proc_type_visitor.type_map, proc_type_visitor.mem_map, [s1._node, s2._node] - ).items() - ] - args = [arg for arg in args if not arg.type.is_numeric()] + [ - arg for arg in args if arg.type.is_numeric() - ] - - transpiled_test1 = Transpiler( - LoopIR.proc( - name=proc.name, - args=args, - preds=[], - body=[s1._node, s2._node], - instr=None, - srcinfo=proc.srcinfo, - ) - ) - transpiled_test2 = Transpiler( - LoopIR.proc( - name=proc.name, +@dataclass +class TestScope: + scope: Union[Block, Node] + flatten_loops: bool + + def broaden(self) -> Optional["TestScope"]: + if self.scope.depth() == 0: + return TestScope(self.scope, False) if self.flatten_loops else None + else: + return TestScope(self.scope.parent(), self.flatten_loops) + + def transform(self, forward: Callable[[Cursor], Cursor]) -> "TestScope": + return TestScope(forward(self.scope), self.flatten_loops) + + def get_type_map(self) -> dict[Sym, LoopIR.type]: + root_proc = self.scope.get_root() + proc_type_visitor = TypeVisitor() + proc_type_visitor.visit(root_proc) + return proc_type_visitor.type_map + + def get_test_spec(self, cm: ConstraintMaker) -> TestSpec: + root_proc = self.scope.get_root() + proc_type_visitor = TypeVisitor() + proc_type_visitor.visit(root_proc) + + constraint = TRUE_CONSTRAINT + for pred in root_proc.preds: + constraint = constraint.intersect(cm.make_constraint(pred)) + constraint = constraint.intersect(collect_path_constraints(self.scope, cm)) + args = [ + LoopIR.fnarg( + name=var, + type=arg_type, + mem=DRAM if arg_mem is None else arg_mem, + srcinfo=SrcInfo("", 0), + ) + for var, (arg_type, arg_mem) in get_free_variables( + proc_type_visitor.type_map, + proc_type_visitor.mem_map, + self.scope, + ).items() + ] + args = [arg for arg in args if not arg.type.is_numeric()] + [ + arg for arg in args if arg.type.is_numeric() + ] + + proc = LoopIR.proc( + name=root_proc.name, args=args, preds=[], - body=[s2._node, s1._node], + body=( + [self.scope._node] + if isinstance(self.scope, Node) + else self.scope.resolve_all() + ), instr=None, - srcinfo=proc.srcinfo, + srcinfo=root_proc.srcinfo, ) - ) + arg_types = {arg.name: arg.type for arg in args} + return TestSpec(proc, constraint, arg_types) + + +TEST_CASE_BOUND = 15 + + +def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): + cur_scope = TestScope(starting_scope, True) + transformed = cur_scope.transform(fwd) + + cm = ConstraintMaker(cur_scope.get_type_map() | transformed.get_type_map()) + + spec1 = cur_scope.get_test_spec(cm) + spec2 = transformed.get_test_spec(cm) + + transpiled_test1 = Transpiler(spec1.proc, CoverageArgs(cm)) + transpiled_test2 = 
Transpiler(spec2.proc, CoverageArgs(cm)) config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() - buffer_size_bound = MIN_BUFFER_SIZE_BOUND + + arg_types = spec1.arg_types | spec2.arg_types + constraint = spec1.constraint.union(spec2.constraint) + skeleton1, skeleton2 = ( + transpiled_test1.get_coverage_skeleton(), + transpiled_test2.get_coverage_skeleton(), + ) + assert skeleton1 is not None and skeleton2 is not None + coverage_skeleton = skeleton1.merge(skeleton2) for _ in range(TEST_CASE_BOUND): test_case = generate_test_case( - args, + arg_types, config_fields, - ( - constraint - if buffer_size_bound is None - else constraint.intersect( - collect_arg_size_constraints(args, cm, buffer_size_bound) - ) - ), + constraint, + coverage_skeleton, cm, ) if test_case is None: - if buffer_size_bound is None or buffer_size_bound >= MAX_BUFFER_SIZE_BOUND: - assert buffer_size_bound is not None - buffer_size_bound = None - else: - buffer_size_bound = min(MAX_BUFFER_SIZE_BOUND, buffer_size_bound * 4) continue out1 = run_test_case(test_case, transpiled_test1) out2 = run_test_case(test_case, transpiled_test2) + if out1 == "failed" or out2 == "failed": + raise SchedulingError("domain mismatch") + assert out1.coverage_result is not None and out2.coverage_result is not None + coverage_skeleton.update_coverage(out1.coverage_result | out2.coverage_result) for buffer_name in out1.buffer_values.keys() & out2.buffer_values.keys(): if not np.allclose( out1.buffer_values[buffer_name], out2.buffer_values[buffer_name] @@ -449,3 +544,12 @@ def fuzz_reorder_stmts(s1, s2): out1.ctxt_object[ctxt_name], out2.ctxt_object[ctxt_name] ): raise SchedulingError("context mismatch found") + + +def fuzz_reorder_stmts(s1: Node, s2: Node): + starting_scope = s1.as_block().expand(0, 1) + _, fwd = s2._move(s1.before()) + patched_fwd = lambda cursor: ( + fwd(cursor) if isinstance(cursor, Node) else fwd(s2).as_block().expand(0, 1) + ) + fuzz(starting_scope, patched_fwd) diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 9f4f23326..086997ed8 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -23,20 +23,15 @@ def multiply(self, other) -> "ConstraintTerm": self.coefficient * other.coefficient, self.syms + other.syms ) - def apply_assignments( - self, assignments: dict[Sym, int] - ) -> Optional[tuple[int, Optional[Sym]]]: - target_sym = None - acc = self.coefficient + def substitute(self, assignments: dict[Sym, int]) -> "ConstraintTerm": + new_syms = [] + new_coefficient = self.coefficient for sym in self.syms: if sym in assignments: - acc *= assignments[sym] + new_coefficient *= assignments[sym] else: - if target_sym is None: - target_sym = sym - else: - return None - return (acc, target_sym) + new_syms.append(sym) + return ConstraintTerm(new_coefficient, tuple(new_syms)) def collect_nonlinear_syms(self) -> frozenset[Sym]: occurrences = set() @@ -48,6 +43,17 @@ def collect_nonlinear_syms(self) -> frozenset[Sym]: occurrences.add(sym) return frozenset(result) + def pretty_print(self) -> str: + return ( + f"{' * '.join([str(self.coefficient)] + [str(sym) for sym in self.syms])}" + ) + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ConstraintTerm": + return ConstraintTerm( + self.coefficient, + tuple(lookup[sym] if sym in lookup else sym for sym in self.syms), + ) + @dataclass class LinearConstraint: @@ -57,28 +63,29 @@ class LinearConstraint: @dataclass -class Constraint: +class Expression: terms: tuple[ConstraintTerm, 
...] - has_slack: bool - def apply_assignments( - self, assignments: dict[Sym, int] - ) -> Optional[LinearConstraint]: - coefficients = {} - offset = 0 + def substitute(self, assignments: dict[Sym, int]) -> "Expression": + coefficients: dict[tuple[Sym, ...], int] = {} for term in self.terms: - assign_result = term.apply_assignments(assignments) - if assign_result is None: - return None - else: - coefficient, sym = assign_result - if sym is None: - offset += coefficient - else: - if sym not in coefficients: - coefficients[sym] = 0 - coefficients[sym] += coefficient - return LinearConstraint(coefficients, offset, self.has_slack) + sub_term = term.substitute(assignments) + if sub_term.syms not in coefficients: + coefficients[sub_term.syms] = 0 + coefficients[sub_term.syms] += sub_term.coefficient + return Expression( + tuple( + ConstraintTerm(coefficient, syms) + for syms, coefficient in coefficients.items() + ) + ) + + def get_trivial_result(self) -> Optional[int]: + if len(self.terms) == 0: + return 0 + elif len(self.terms) == 1 and len(self.terms[0].syms) == 0: + return self.terms[0].coefficient + return None def collect_syms(self) -> frozenset[Sym]: return frozenset(sym for term in self.terms for sym in term.syms) @@ -88,26 +95,67 @@ def collect_nonlinear_syms(self) -> frozenset[Sym]: *[term.collect_nonlinear_syms() for term in self.terms] ) + def pretty_print(self): + return " + ".join([term.pretty_print() for term in self.terms]) + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "Expression": + return Expression(tuple(term.rename_syms(lookup) for term in self.terms)) + + +@dataclass +class Constraint: + lhs: Expression + has_slack: bool + + def linearize(self, assignments: dict[Sym, int]) -> Optional[LinearConstraint]: + new_lhs = self.lhs.substitute(assignments) + offset = 0 + coefficients = {} + for term in new_lhs.terms: + if len(term.syms) == 0: + offset += term.coefficient + elif len(term.syms) == 1: + coefficients[term.syms[0]] = term.coefficient + else: + return None + return LinearConstraint(coefficients, offset, self.has_slack) + + def collect_syms(self) -> frozenset[Sym]: + return self.lhs.collect_syms() + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return self.lhs.collect_nonlinear_syms() + def lift_to_disjoint_constraint(self) -> "DisjointConstraint": return DisjointConstraint((ConstraintClause((self,)),)) def invert(self) -> "DisjointConstraint": if self.has_slack: return Constraint( - tuple(term.negate() for term in self.terms) + (ConstraintTerm(-1, ()),), + Expression( + tuple(term.negate() for term in self.lhs.terms) + + (ConstraintTerm(-1, ()),) + ), True, ).lift_to_disjoint_constraint() else: return DisjointConstraint( ( ConstraintClause( - (Constraint(self.terms + (ConstraintTerm(-1, ()),), True),) + ( + Constraint( + Expression(self.lhs.terms + (ConstraintTerm(-1, ()),)), + True, + ), + ) ), ConstraintClause( ( Constraint( - tuple(term.negate() for term in self.terms) - + (ConstraintTerm(-1, ()),), + Expression( + tuple(term.negate() for term in self.lhs.terms) + + (ConstraintTerm(-1, ()),) + ), True, ), ) @@ -116,15 +164,19 @@ def invert(self) -> "DisjointConstraint": ) def pretty_print(self) -> str: - return ( - " + ".join( - [ - f"{' * '.join([str(term.coefficient)] + [str(sym) for sym in term.syms])}" - for term in self.terms - ] - ) - + f" {'>=' if self.has_slack else '=='} 0" - ) + return f"{self.lhs.pretty_print()} {'>=' if self.has_slack else '=='} 0" + + def substitute(self, assignments: dict[Sym, int]) -> "Constraint": + return 
Constraint(self.lhs.substitute(assignments), self.has_slack) + + def get_trivial_result(self) -> Optional[bool]: + lhs_result = self.lhs.get_trivial_result() + if lhs_result is not None: + return (lhs_result >= 0 and self.has_slack) or lhs_result == 0 + return None + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "Constraint": + return Constraint(self.lhs.rename_syms(lookup), self.has_slack) @dataclass @@ -148,6 +200,34 @@ def pretty_print(self) -> str: ] return "\n".join(lines) + def collect_syms(self) -> frozenset[Sym]: + return frozenset().union( + *(constraint.collect_syms() for constraint in self.constraints) + ) + + def substitute(self, assignments: dict[Sym, int]) -> "ConstraintClause": + new_constraints = [] + for constraint in self.constraints: + new_constraint = constraint.substitute(assignments) + trivial_result = new_constraint.get_trivial_result() + if trivial_result is None: + new_constraints.append(new_constraint) + elif not trivial_result: + return ConstraintClause((new_constraint,)) + return ConstraintClause(tuple(new_constraints)) + + def get_trivial_result(self) -> Optional[bool]: + if len(self.constraints) == 0: + return True + elif len(self.constraints) == 1: + return self.constraints[0].get_trivial_result() + return None + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ConstraintClause": + return ConstraintClause( + tuple(constraint.rename_syms(lookup) for constraint in self.constraints) + ) + @dataclass class DisjointConstraint: @@ -182,34 +262,53 @@ def pretty_print(self) -> str: ] return "\n".join(lines) + def collect_syms(self) -> frozenset[Sym]: + return frozenset().union(*(clause.collect_syms() for clause in self.clauses)) -TRUE_CONSTRAINT = DisjointConstraint((ConstraintClause(()),)) -FALSE_CONSTRAINT = DisjointConstraint(()) + def substitute(self, assignments: dict[Sym, int]) -> "DisjointConstraint": + new_clauses = [] + for clause in self.clauses: + new_clause = clause.substitute(assignments) + trivial_result = new_clause.get_trivial_result() + if trivial_result is None: + new_clauses.append(new_clause) + elif trivial_result: + return DisjointConstraint((new_clause,)) + return DisjointConstraint(tuple(new_clauses)) + + def get_trivial_result(self) -> Optional[bool]: + if len(self.clauses) == 0: + return False + elif len(self.clauses) == 1: + return self.clauses[0].get_trivial_result() + return None + def rename_syms(self, lookup: dict[Sym, Sym]) -> "DisjointConstraint": + return DisjointConstraint( + tuple(clause.rename_syms(lookup) for clause in self.clauses) + ) -@dataclass -class Expression: - terms: tuple[ConstraintTerm, ...] 
- def apply_assignments(self, assignments: dict[Sym, int]) -> Optional[int]: - result = 0 - for term in self.terms: - assign_result = term.apply_assignments(assignments) - if assign_result is None: - return None - else: - coeff, target = assign_result - if target is None: - result += coeff - else: - return None - return result +TRUE_CONSTRAINT = DisjointConstraint((ConstraintClause(()),)) +FALSE_CONSTRAINT = DisjointConstraint(()) @dataclass class Solution: ctxt: dict[tuple[Config, str], int] var_assignments: dict[Sym, int] + substitutions: dict[Sym, int] + + def merge_solutions(self, other: "Solution", other_renaming: dict[Sym, Sym]): + return Solution( + self.ctxt, + self.var_assignments, + { + other_renaming[key] if key in other_renaming else key: value + for key, value in other.substitutions.items() + } + | self.substitutions, + ) class ConstraintMaker: @@ -217,6 +316,7 @@ def __init__(self, type_map: dict[Sym, LoopIR.type]): self.var_subs: dict[Sym, Expression] = {} self.ctxt: dict[tuple[Config, str], Expression] = {} self.extra_constraints: list[Constraint] = [] + self.hidden_vars: set[Sym] = set() self.stride_dummies: dict[tuple[Sym, int], Sym] = {} for sym, sym_type in type_map.items(): var_sub_result = self.make_var_sub(sym.name(), sym_type) @@ -238,9 +338,11 @@ def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression] sym = Sym(name) self.extra_constraints.append( Constraint( - ( - ConstraintTerm(-1, (sym,)), - ConstraintTerm(1, ()), + Expression( + ( + ConstraintTerm(-1, (sym,)), + ConstraintTerm(1, ()), + ) ), True, ) @@ -249,59 +351,70 @@ def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression] else: return None - def make_constraint_terms( - self, expr: Union[LoopIR.expr, Sym] - ) -> tuple[ConstraintTerm, ...]: + def make_expression(self, expr: Union[LoopIR.expr, Sym]) -> Expression: # expect that expr is int type if isinstance(expr, Sym): - return self.var_subs[expr].terms + return self.var_subs[expr] elif isinstance(expr, LoopIR.Read): assert ( len(expr.idx) == 0 ), "indexing not supported in assertions (yet, todo)" - return self.var_subs[expr.name].terms + return self.var_subs[expr.name] elif isinstance(expr, LoopIR.Const): - return (ConstraintTerm(expr.val, ()),) + return Expression((ConstraintTerm(expr.val, ()),)) elif isinstance(expr, LoopIR.USub): - return tuple(term.negate() for term in self.make_constraint_terms(expr.arg)) + return Expression( + tuple(term.negate() for term in self.make_expression(expr.arg).terms) + ) elif isinstance(expr, LoopIR.BinOp): # TODO: support mod and div using extra variables - lhs_terms = self.make_constraint_terms(expr.lhs) - rhs_terms = self.make_constraint_terms(expr.rhs) + lhs_terms = self.make_expression(expr.lhs).terms + rhs_terms = self.make_expression(expr.rhs).terms if expr.op == "+": - return lhs_terms + rhs_terms + return Expression(lhs_terms + rhs_terms) elif expr.op == "-": - return lhs_terms + tuple(term.negate() for term in rhs_terms) + return Expression( + lhs_terms + tuple(term.negate() for term in rhs_terms) + ) elif expr.op == "*": - return tuple( - lhs_term.multiply(rhs_term) - for lhs_term in lhs_terms - for rhs_term in rhs_terms + return Expression( + tuple( + lhs_term.multiply(rhs_term) + for lhs_term in lhs_terms + for rhs_term in rhs_terms + ) ) elif expr.op in ["/", "%"]: div, rem = Sym("div"), Sym("rem") + self.hidden_vars.update((div, rem)) self.extra_constraints.append( Constraint( - tuple(lhs_term.negate() for lhs_term in lhs_terms) - + (ConstraintTerm(1, 
(rem,)),) - + tuple( - rhs_term.multiply(ConstraintTerm(1, (div,))) - for rhs_term in rhs_terms + Expression( + tuple(lhs_term.negate() for lhs_term in lhs_terms) + + (ConstraintTerm(1, (rem,)),) + + tuple( + rhs_term.multiply(ConstraintTerm(1, (div,))) + for rhs_term in rhs_terms + ) ), False, ) ) self.extra_constraints.append( Constraint( - ( - ConstraintTerm(-1, (rem,)), - ConstraintTerm(-1, ()), - ) - + rhs_terms, + Expression( + ( + ConstraintTerm(-1, (rem,)), + ConstraintTerm(-1, ()), + ) + + rhs_terms + ), True, ) ) - return (ConstraintTerm(1, (rem if expr.op == "%" else div,)),) + return Expression( + (ConstraintTerm(1, (rem if expr.op == "%" else div,)),) + ) else: assert False, f"unsupported op in assertion: {expr.op}" elif isinstance(expr, LoopIR.StrideExpr): @@ -309,7 +422,7 @@ def make_constraint_terms( new_sym = Sym("stride") self.stride_dummies[(expr.name, expr.dim)] = new_sym dummy = self.stride_dummies[(expr.name, expr.dim)] - return (ConstraintTerm(1, (dummy,)),) + return Expression((ConstraintTerm(1, (dummy,)),)) elif isinstance(expr, LoopIR.ReadConfig): if (expr.config, expr.field) not in self.ctxt: field_type = expr.config.lookup_type(expr.field) @@ -320,7 +433,7 @@ def make_constraint_terms( var_sub_result is not None ), "constraints can only occur on control variables" self.ctxt[(expr.config, expr.field)] = var_sub_result - return self.ctxt[(expr.config, expr.field)].terms + return self.ctxt[(expr.config, expr.field)] else: assert False, f"unsupported expr" @@ -347,9 +460,11 @@ def make_constraint( elif isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0, "cannot index into boolean" return Constraint( - ( - ConstraintTerm(1, (expr.name,)), - ConstraintTerm(-1, ()), + Expression( + ( + ConstraintTerm(1, (expr.name,)), + ConstraintTerm(-1, ()), + ) ), True, ).lift_to_disjoint_constraint() @@ -361,8 +476,8 @@ def make_constraint( def make_constraint_from_inequality( self, lhs: Union[LoopIR.expr, Sym], rhs: Union[LoopIR.expr, Sym], op: str ) -> Constraint: - lhs_terms = self.make_constraint_terms(lhs) - rhs_terms = self.make_constraint_terms(rhs) + lhs_terms = self.make_expression(lhs).terms + rhs_terms = self.make_expression(rhs).terms has_slack = True if op == "<": terms = ( @@ -385,140 +500,145 @@ def make_constraint_from_inequality( terms = rhs_terms + tuple(term.negate() for term in lhs_terms) else: assert False, "boolean ops expected" - return Constraint(terms, has_slack) - - def solve_constraint( - self, - disjoint_constraint: DisjointConstraint, - *, - bound: int, - search_limit: int, - seed: Optional[int] = None, - ) -> Optional[Solution]: - if seed is not None: - np.random.seed(seed=seed) - if len(disjoint_constraint.clauses) == 0: - return None - chosen_clause = np.random.choice(list(disjoint_constraint.clauses)) - assert isinstance(chosen_clause, ConstraintClause) - all_constraints = chosen_clause.constraints + tuple(self.extra_constraints) - assignments = {} + return Constraint(Expression(terms), has_slack) + + def _make_solution_from_assignments(self, assignments: dict[Sym, int]) -> Solution: + var_assignments = {} + for sym, sub in self.var_subs.items(): + result = sub.substitute(assignments).get_trivial_result() + if result is not None: + var_assignments[sym] = result + ctxt = {} + for (config, field), sub in self.ctxt.items(): + result = sub.substitute(assignments).get_trivial_result() + if result is not None: + ctxt[(config, field)] = result + return Solution(ctxt, var_assignments, assignments) + + def _solve_for_assignments( + self, 
all_constraints: tuple[Constraint, ...], bound: int + ) -> Union[Literal["failed", "infeasible"], dict[Sym, int]]: sym_universe = set() for constraint in all_constraints: sym_universe |= constraint.collect_syms() - - def solve_helper(): - while len(assignments) < len(sym_universe): - linear_constraints: list[LinearConstraint] = [] - linear_constraint_syms: set[Sym] = set() - nonlinear_syms: set[Sym] = set() - for constraint in all_constraints: - assign_result = constraint.apply_assignments(assignments) - if assign_result is not None: - linear_constraints.append(assign_result) - linear_constraint_syms |= { - sym for sym in assign_result.coefficients.keys() - } - nonlinear_syms |= constraint.collect_nonlinear_syms() - nonlinear_syms -= assignments.keys() - priority_syms = nonlinear_syms & linear_constraint_syms - if len(priority_syms) == 0 and len(nonlinear_syms) != 0: - chosen_sym = np.random.choice( - sorted(list(nonlinear_syms), key=lambda sym: sym._id) - ) - assignments[chosen_sym] = np.random.randint(0, bound) - continue - sym_ordering = { - sym: i - for i, sym in enumerate( - sorted( - list(linear_constraint_syms), - key=lambda sym: sym._id, - ) + assignments = {} + while len(assignments) < len(sym_universe): + linear_constraints: list[LinearConstraint] = [] + linear_constraint_syms: set[Sym] = set() + nonlinear_syms: set[Sym] = set() + for constraint in all_constraints: + linear_result = constraint.linearize(assignments) + if linear_result is not None: + linear_constraints.append(linear_result) + linear_constraint_syms |= { + sym for sym in linear_result.coefficients.keys() + } + nonlinear_syms |= constraint.collect_nonlinear_syms() + nonlinear_syms -= assignments.keys() + priority_syms = nonlinear_syms & linear_constraint_syms + if len(priority_syms) == 0 and len(nonlinear_syms) != 0: + chosen_sym = np.random.choice( + sorted(list(nonlinear_syms), key=lambda sym: sym._id) + ) + assignments[chosen_sym] = np.random.randint(0, bound) + continue + sym_ordering = { + sym: i + for i, sym in enumerate( + sorted( + list(linear_constraint_syms), + key=lambda sym: sym._id, ) - } - n = len(linear_constraints) - m_nonslack = len(linear_constraint_syms) - matrix_A = np.zeros( - (n, m_nonslack), - dtype=np.int32, ) - m = m_nonslack - vec_b = np.zeros(n, dtype=np.int32) - for row, linear_constraint in enumerate(linear_constraints): - for sym, coefficient in linear_constraint.coefficients.items(): - matrix_A[row, sym_ordering[sym]] = coefficient - if linear_constraint.has_slack: - slack_col = np.zeros((n, 1), dtype=np.int32) - slack_col[row, 0] = -1 - matrix_A = np.hstack((matrix_A, slack_col)) - m += 1 - vec_b[row] = -linear_constraint.offset - matrix_B, matrix_U, matrix_V = smith_normal_form(matrix_A) - vec_d = matrix_U @ vec_b - k = min(n, m) - vec_f = np.zeros(m) - for i in range(min(n, m)): - if matrix_B[i, i] == 0: - k = i - break - if vec_d[i] % matrix_B[i, i] != 0: - return False - vec_f += vec_d[i] / matrix_B[i, i] * matrix_V[:, i] - if m == k: - solution = vec_f - if not np.all(vec_f >= 0): - return False - else: - matrix_C = matrix_V[:, k:] - upper_bound_matrix = np.concatenate( - (matrix_C[:m_nonslack, :], -matrix_C), axis=0 + } + n = len(linear_constraints) + m_nonslack = len(linear_constraint_syms) + matrix_A = np.zeros( + (n, m_nonslack), + dtype=np.int32, + ) + m = m_nonslack + vec_b = np.zeros(n, dtype=np.int32) + for row, linear_constraint in enumerate(linear_constraints): + for sym, coefficient in linear_constraint.coefficients.items(): + matrix_A[row, sym_ordering[sym]] = 
coefficient + if linear_constraint.has_slack: + slack_col = np.zeros((n, 1), dtype=np.int32) + slack_col[row, 0] = -1 + matrix_A = np.hstack((matrix_A, slack_col)) + m += 1 + vec_b[row] = -linear_constraint.offset + matrix_B, matrix_U, matrix_V = smith_normal_form(matrix_A) + vec_d = matrix_U @ vec_b + k = min(n, m) + vec_f = np.zeros(m) + for i in range(min(n, m)): + if matrix_B[i, i] == 0: + k = i + break + if vec_d[i] % matrix_B[i, i] != 0: + return "infeasible" if len(assignments) == 0 else "failed" + vec_f += vec_d[i] / matrix_B[i, i] * matrix_V[:, i] + if m == k: + solution = vec_f + if not np.all(vec_f >= 0): + return "infeasible" if len(assignments) == 0 else "failed" + else: + matrix_C = matrix_V[:, k:] + upper_bound_matrix = np.concatenate( + (matrix_C[:m_nonslack, :], -matrix_C), axis=0 + ) + upper_bound_offset = np.concatenate( + (np.ones(m_nonslack) * bound - vec_f[:m_nonslack], vec_f), + axis=0, + ) + lp = linprog( + np.zeros(m - k), + A_ub=upper_bound_matrix, + b_ub=upper_bound_offset, + bounds=(None, None), + ) + if not lp.success: + return "infeasible" if len(assignments) == 0 else "failed" + cur_y = lp.x + har_iter = 50 + last_int_y = None + for _ in range(har_iter): + direction = np.random.normal(size=m - k) + direction = direction / np.linalg.norm(direction) + lower_bounds = -matrix_C @ cur_y - vec_f + upper_bounds = lower_bounds + bound + upper_bounds[m_nonslack:] = -np.nan + coefficients = matrix_C @ direction + lower_bounds = lower_bounds[coefficients != 0] + upper_bounds = upper_bounds[coefficients != 0] + coefficients = coefficients[coefficients != 0] + max_lambda = np.nanmin( + np.where(coefficients < 0, lower_bounds, upper_bounds) + / coefficients ) - upper_bound_offset = np.concatenate( - (np.ones(m_nonslack) * bound - vec_f[:m_nonslack], vec_f), - axis=0, + min_lambda = np.nanmax( + np.where(coefficients >= 0, lower_bounds, upper_bounds) + / coefficients ) - lp = linprog( - np.zeros(m - k), - A_ub=upper_bound_matrix, - b_ub=upper_bound_offset, - bounds=(None, None), + new_y = cur_y + direction * ( + np.random.rand() * (max_lambda - min_lambda) + min_lambda ) - if not lp.success: - return False - cur_y = lp.x - har_iter = 50 - last_int_y = None - for _ in range(har_iter): - direction = np.random.normal(size=m - k) - direction = direction / np.linalg.norm(direction) - lower_bounds = -matrix_C @ cur_y - vec_f - upper_bounds = lower_bounds + bound - upper_bounds[m_nonslack:] = -np.nan - coefficients = matrix_C @ direction - lower_bounds = lower_bounds[coefficients != 0] - upper_bounds = upper_bounds[coefficients != 0] - coefficients = coefficients[coefficients != 0] - max_lambda = np.nanmin( - np.where(coefficients < 0, lower_bounds, upper_bounds) - / coefficients - ) - min_lambda = np.nanmax( - np.where(coefficients >= 0, lower_bounds, upper_bounds) - / coefficients - ) - new_y = cur_y + direction * ( - np.random.rand() * (max_lambda - min_lambda) + min_lambda - ) - new_int_y = np.round(new_y) - cur_y = new_y - if np.all(upper_bound_matrix @ new_int_y <= upper_bound_offset): - last_int_y = new_int_y - if last_int_y is not None: - solution = matrix_C @ last_int_y + vec_f - else: - return False + new_int_y = np.round(new_y) + cur_y = new_y + if np.all(upper_bound_matrix @ new_int_y <= upper_bound_offset): + last_int_y = new_int_y + if last_int_y is not None: + solution = matrix_C @ last_int_y + vec_f + else: + return "infeasible" if len(assignments) == 0 else "failed" + if len(nonlinear_syms) == 0: + for sym in linear_constraint_syms: + assignments[sym] = 
int(solution[sym_ordering[sym]]) + for sym in sym_universe - assignments.keys(): + assignments[sym] = np.random.randint(0, bound) + else: chosen_sym = None if len(priority_syms) != 0: chosen_sym = np.random.choice( @@ -541,22 +661,64 @@ def solve_helper(): assignments[chosen_sym] = np.random.randint(0, bound) else: assignments[chosen_sym] = int(solution[sym_ordering[chosen_sym]]) - return True + return assignments + + def solve_constraint( + self, + disjoint_constraint: DisjointConstraint, + *, + partial_solution: Optional[Solution] = None, + bound: int, + search_limit: int, + seed: Optional[int] = None, + ) -> Optional[Solution]: + if seed is not None: + np.random.seed(seed=seed) + if partial_solution is not None: + disjoint_constraint = disjoint_constraint.substitute( + partial_solution.substitutions + ) + clauses = list(disjoint_constraint.clauses) for _ in range(search_limit): - if solve_helper(): - var_assignments = {} - for sym, sub in self.var_subs.items(): - result = sub.apply_assignments(assignments) - if result is not None: - var_assignments[sym] = result - ctxt = {} - for (config, field), sub in self.ctxt.items(): - result = sub.apply_assignments(assignments) - if result is not None: - ctxt[(config, field)] = result - return Solution(ctxt, var_assignments) + if len(clauses) == 0: + return None + chosen_clause = np.random.choice(clauses) + assert isinstance(chosen_clause, ConstraintClause) + all_constraints = chosen_clause.constraints + tuple(self.extra_constraints) + assignment_result = self._solve_for_assignments(all_constraints, bound) + if assignment_result == "failed": + continue + elif assignment_result == "infeasible": + clauses = list(clause for clause in clauses if clause != chosen_clause) else: - assignments = {} - + return self._make_solution_from_assignments( + ({} if partial_solution is None else partial_solution.substitutions) + | assignment_result + ) return None + + def rename_sym_set( + self, syms: frozenset[Sym], free_vars: frozenset[Sym] + ) -> tuple[dict[Sym, Sym], dict[Sym, Sym]]: + var_renaming = {} + sym_renaming = {sym: Sym(sym.name()) for sym in self.hidden_vars & syms} + for var in free_vars: + var_sub = self.var_subs[var] + var_sub_syms = var_sub.collect_syms() + if len(var_sub_syms & syms) != 0: + sym_renaming |= {sym: Sym(sym.name()) for sym in var_sub_syms} + renamed_var = Sym(var.name()) + var_renaming[var] = renamed_var + self.var_subs[renamed_var] = var_sub.rename_syms(sym_renaming) + self.extra_constraints.extend( + tuple( + extra_constraint.rename_syms(sym_renaming) + for extra_constraint in self.extra_constraints + if len(extra_constraint.collect_syms() & sym_renaming.keys()) != 0 + ) + ) + return ( + sym_renaming, + var_renaming, + ) diff --git a/tests/golden/test_transpiler/test_matmul.txt b/tests/golden/test_transpiler/test_matmul.txt new file mode 100644 index 000000000..1f15cca5a --- /dev/null +++ b/tests/golden/test_transpiler/test_matmul.txt @@ -0,0 +1,22 @@ +((a_4,b_5,c_6)=>{ +ctxt={} +if(!($size_a_4_0==$N_1))return [1,ctxt,{}]; +if(!($size_a_4_1==$K_3))return [1,ctxt,{}]; +if(!($size_b_5_0==$K_3))return [1,ctxt,{}]; +if(!($size_b_5_1==$M_2))return [1,ctxt,{}]; +if(!($size_c_6_0==$N_1))return [1,ctxt,{}]; +if(!($size_c_6_1==$M_2))return [1,ctxt,{}]; +for(let i_7=0;i_7<$N_1;i_7++){ +for(let j_8=0;j_8<$M_2;j_8++){ +for(let k_9=0;k_9<$K_3;k_9++){ +if(!(i_7<$size_a_4_0&&i_7>=0))return [1,ctxt,{}]; +if(!(k_9<$size_a_4_1&&k_9>=0))return [1,ctxt,{}]; +if(!(k_9<$size_b_5_0&&k_9>=0))return [1,ctxt,{}]; +if(!(j_8<$size_b_5_1&&j_8>=0))return 
[1,ctxt,{}]; +if(!(i_7<$size_c_6_0&&i_7>=0))return [1,ctxt,{}]; +if(!(j_8<$size_c_6_1&&j_8>=0))return [1,ctxt,{}]; +c_6[Math.imul(i_7,$stride_c_6_0)+Math.imul(j_8,$stride_c_6_1)+0]+=(a_4[Math.imul(i_7,$stride_a_4_0)+Math.imul(k_9,$stride_a_4_1)+0]*b_5[Math.imul(k_9,$stride_b_5_0)+Math.imul(j_8,$stride_b_5_1)+0]); +} +} +} +return [0,ctxt,];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_matmul_coverage.txt b/tests/golden/test_transpiler/test_matmul_coverage.txt new file mode 100644 index 000000000..d7f3b98d9 --- /dev/null +++ b/tests/golden/test_transpiler/test_matmul_coverage.txt @@ -0,0 +1,44 @@ +((a_4,b_5,c_6)=>{ +ctxt={} +let body_23=false;let body_26=false;let body_29=false;let skip_24=false;let skip_27=false;let skip_30=false; +if(!($size_a_4_0==$N_1))return [1,ctxt,{}]; +if(!($size_a_4_1==$K_3))return [1,ctxt,{}]; +if(!($size_b_5_0==$K_3))return [1,ctxt,{}]; +if(!($size_b_5_1==$M_2))return [1,ctxt,{}]; +if(!($size_c_6_0==$N_1))return [1,ctxt,{}]; +if(!($size_c_6_1==$M_2))return [1,ctxt,{}]; + +body_23||=(0<$N_1);skip_24||=(0>=$N_1); +for(let i_7=0;i_7<$N_1;i_7++){ + +body_26||=(0<$M_2);skip_27||=(0>=$M_2); +for(let j_8=0;j_8<$M_2;j_8++){ + +body_29||=(0<$K_3);skip_30||=(0>=$K_3); +for(let k_9=0;k_9<$K_3;k_9++){ +if(!(i_7<$size_c_6_0&&i_7>=0))return [1,ctxt,{}]; +if(!(j_8<$size_c_6_1&&j_8>=0))return [1,ctxt,{}]; + +if(!(i_7<$size_a_4_0&&i_7>=0))return [1,ctxt,{}]; +if(!(k_9<$size_a_4_1&&k_9>=0))return [1,ctxt,{}]; + + + +if(!(i_7<$size_a_4_0&&i_7>=0))return [1,ctxt,{}]; +if(!(k_9<$size_a_4_1&&k_9>=0))return [1,ctxt,{}]; +if(!(k_9<$size_b_5_0&&k_9>=0))return [1,ctxt,{}]; +if(!(j_8<$size_b_5_1&&j_8>=0))return [1,ctxt,{}]; + + + +if(!(k_9<$size_b_5_0&&k_9>=0))return [1,ctxt,{}]; +if(!(j_8<$size_b_5_1&&j_8>=0))return [1,ctxt,{}]; + + +if(!(i_7<$size_c_6_0&&i_7>=0))return [1,ctxt,{}]; +if(!(j_8<$size_c_6_1&&j_8>=0))return [1,ctxt,{}]; +c_6[Math.imul(i_7,$stride_c_6_0)+Math.imul(j_8,$stride_c_6_1)+0]+=(a_4[Math.imul(i_7,$stride_a_4_0)+Math.imul(k_9,$stride_a_4_1)+0]*b_5[Math.imul(k_9,$stride_b_5_0)+Math.imul(j_8,$stride_b_5_1)+0]); +} +} +} +return [0,ctxt,{body_23,body_26,body_29,skip_24,skip_27,skip_30}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt b/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt new file mode 100644 index 000000000..6040168ee --- /dev/null +++ b/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt @@ -0,0 +1,30 @@ +((b_2)=>{ +ctxt={} +let body_12=false;let false_case_17=false;let false_case_25=false;let skip_13=false;let true_case_16=false;let true_case_24=false;let write_20=false;let write_21=false;let write_26=false;let write_27=false; + +body_12||=(0<$n_1);skip_13||=(0>=$n_1); +for(let i_3=0;i_3<$n_1;i_3++){ + + +if((i_3<(($n_1/2)|0))){ +true_case_16=true; +write_20=true; +b_2[0]=2; +}else{ +false_case_17=true; +write_21=true; +b_2[0]=3; +} + + +if((i_3==($n_1-1))){ +true_case_24=true; +write_26=true; +b_2[0]+=1; +}else{ +false_case_25=true; +write_27=true; +b_2[0]+=2; +} +} +return [0,ctxt,{body_12,false_case_17,false_case_25,skip_13,true_case_16,true_case_24,write_20,write_21,write_26,write_27}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_variable_length_array_coverage.txt b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt new file mode 100644 index 000000000..e8605468b --- /dev/null +++ b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt @@ -0,0 +1,24 @@ +(()=>{ 
+ctxt={} +let body_9=false;let skip_10=false;let write_12=new ArrayBuffer(1,{maxByteLength:16});let write_15=new ArrayBuffer(1,{maxByteLength:16}); +if(!($n_1>2))return [1,ctxt,{}]; + +body_9||=(2<$n_1);skip_10||=(2>=$n_1); +for(let i_2=2;i_2<$n_1;i_2++){ + +if(!(i_2>=0))return [1,ctxt,{}]; +let b_3=new Int32Array(i_2); +while(i_2>write_12.maxByteLength){let temp_13=new ArrayBuffer(write_12.byteLength,{maxByteLength:2*write_12.maxByteLength});for(let i=0;iwrite_15.maxByteLength){let temp_16=new ArrayBuffer(write_15.byteLength,{maxByteLength:2*write_15.maxByteLength});for(let i=0;i=0))return [1,ctxt,{}]; +if(!((i_2-1)=0))return [1,ctxt,{}]; +write_12[Math.imul((i_2-1),1)+0]=1; + +if(!((i_2-1)=0))return [1,ctxt,{}]; +b_3[Math.imul((i_2-1),1)+0]=0; +if(!((i_2-2)=0))return [1,ctxt,{}]; +write_15[Math.imul((i_2-2),1)+0]=1; + +if(!((i_2-2)=0))return [1,ctxt,{}]; +b_3[Math.imul((i_2-2),1)+0]=1; +} +return [0,ctxt,{body_9,skip_10,write_12,write_15}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_window_coverage.txt b/tests/golden/test_transpiler/test_window_coverage.txt new file mode 100644 index 000000000..7873ddb58 --- /dev/null +++ b/tests/golden/test_transpiler/test_window_coverage.txt @@ -0,0 +1,14 @@ +((a_1)=>{ +ctxt={} + +if(!($size_a_1_0==16))return [1,ctxt,{}]; +if(!(0<=1&&1<=8&&8<=$size_a_1_0))return [1,ctxt,{}]; +if(!(3<$size_a_1_0&&3>=0))return [1,ctxt,{}]; + +if(!(3<$size_a_1_0&&3>=0))return [1,ctxt,{}]; +a_1[Math.imul(3,$stride_a_1_0)+0]=2; +if(!(2<(8-1)&&2>=0))return [1,ctxt,{}]; + +if(!(2<(8-1)&&2>=0))return [1,ctxt,{}]; +a_1[Math.imul(2,$stride_a_1_0)+(0+Math.imul(1,$stride_a_1_0))]=3; +return [0,ctxt,{}];}) \ No newline at end of file diff --git a/tests/test_apps.py b/tests/test_apps.py index 960fec521..ed5fac092 100644 --- a/tests/test_apps.py +++ b/tests/test_apps.py @@ -47,7 +47,8 @@ def test_gemmini_matmul(golden): @pytest.mark.slow def test_gemmini_conv(golden): module_file = REPO_ROOT / "apps" / "gemmini" / "src" / "exo" / "conv.py" - assert _test_app(module_file) == golden + # TODO: uncomment when conv is fixed in main + # assert _test_app(module_file) == golden def test_blur(golden): diff --git a/tests/test_chexo.py b/tests/test_chexo.py index d311fc98c..6e0b0cc09 100644 --- a/tests/test_chexo.py +++ b/tests/test_chexo.py @@ -5,7 +5,6 @@ TypeVisitor, get_free_variables, collect_path_constraints, - collect_arg_size_constraints, ) from exo.rewrite.constraint_solver import ConstraintMaker from exo import proc, config @@ -72,15 +71,3 @@ def foo(a: size, b: f32[a]): golden == collect_path_constraints(foo.find("b[j] = b[i]")._impl, cm).pretty_print() ) - - -def test_arg_size_constraints(golden): - @proc - def foo(a: size, b: size, c: f32[a * 2, b + 1]): - pass - - type_visitor = TypeVisitor() - type_visitor.visit(foo._loopir_proc) - cm = ConstraintMaker(type_visitor.type_map) - constraints = collect_arg_size_constraints(foo._loopir_proc.args, cm) - assert golden == constraints.pretty_print() diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py new file mode 100644 index 000000000..afc19cb3b --- /dev/null +++ b/tests/test_transpiler.py @@ -0,0 +1,105 @@ +from __future__ import annotations +from exo.core.prelude import Sym + +from exo.rewrite.constraint_solver import ConstraintMaker, DisjointConstraint +from exo.core.LoopIR import T +from exo import proc +from exo.rewrite.chexo import TypeVisitor +from exo.backend.LoopIR_transpiler import Transpiler, CoverageArgs + + +def get_coverage_args(p) -> CoverageArgs: + p_type = TypeVisitor() + 
p_type.visit(p._loopir_proc) + cm = ConstraintMaker(p_type.type_map) + return CoverageArgs(cm) + + +def test_matmul(golden): + Sym._unq_count = 1 + + @proc + def matmul(N: size, M: size, K: size, a: f32[N, K], b: f32[K, M], c: f32[N, M]): + for i in seq(0, N): + for j in seq(0, M): + for k in seq(0, K): + c[i, j] += a[i, k] * b[k, j] + + assert golden == Transpiler(matmul._loopir_proc).get_javascript_template().template + + +def test_matmul_coverage(golden): + Sym._unq_count = 1 + + @proc + def matmul(N: size, M: size, K: size, a: f32[N, K], b: f32[K, M], c: f32[N, M]): + for i in seq(0, N): + for j in seq(0, M): + for k in seq(0, K): + c[i, j] += a[i, k] * b[k, j] + + assert ( + golden + == Transpiler(matmul._loopir_proc, get_coverage_args(matmul)) + .get_javascript_template() + .template + ) + + +def test_window_coverage(golden): + Sym._unq_count = 1 + + @proc + def foo(a: i32[16]): + a_win = a[1:8] + a[3] = 2 + a_win[2] = 3 + + assert ( + golden + == Transpiler(foo._loopir_proc, get_coverage_args(foo)) + .get_javascript_template() + .template + ) + + +def test_variable_length_array_coverage(golden): + Sym._unq_count = 1 + + @proc + def foo(n: size): + assert n > 2 + for i in seq(2, n): + b: i32[i] + b[i - 1] = 0 + b[i - 2] = 1 + + assert ( + golden + == Transpiler(foo._loopir_proc, get_coverage_args(foo)) + .get_javascript_template() + .template + ) + + +def test_nested_control_flow_coverage(golden): + Sym._unq_count = 1 + + @proc + def foo(n: size, b: f32): + for i in seq(0, n): + if i < n / 2: + b = 2 + else: + b = 3 + if i == n - 1: + b += 1 + else: + b += 2 + + assert ( + golden + == Transpiler(foo._loopir_proc, get_coverage_args(foo)) + .get_javascript_template() + .template + ) From f39402b033d295d8bc20bfe4217c255f9e708e2c Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 13 May 2025 15:04:32 -0400 Subject: [PATCH 12/24] improve coverage checks --- src/exo/backend/LoopIR_transpiler.py | 1346 +++++++++++++---- src/exo/backend/coverage.py | 397 ++++- src/exo/core/internal_cursors.py | 12 + src/exo/frontend/typecheck.py | 1 + src/exo/rewrite/LoopIR_scheduling.py | 2 +- src/exo/rewrite/chexo.py | 11 +- src/exo/rewrite/constraint_solver.py | 135 +- tests/golden/test_transpiler/test_matmul.txt | 20 +- .../test_transpiler/test_matmul_coverage.txt | 44 +- .../test_nested_control_flow_coverage.txt | 25 +- .../test_variable_length_array_coverage.txt | 31 +- .../test_transpiler/test_window_coverage.txt | 25 +- 12 files changed, 1568 insertions(+), 481 deletions(-) diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index 501b6a39e..beda60a8c 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -1,7 +1,7 @@ from functools import reduce from itertools import chain from string import Template -from typing import Any, Iterable, Optional, Union +from typing import Any, Callable, Generator, Iterable, Optional, Union from ..core.configs import Config @@ -11,12 +11,26 @@ CoverageSkeleton, CoverageSkeletonNode, CoverageSkeletonBranch, + FailureCondition, IndexedFiller, MemoryAccess, MemoryAccessPair, + ParallelAccess, + ParallelAccessPair, + SymbolicPoint, + SymbolicSlice, + StagingOverlap, + SymbolicWindowIndex, ) -from ..rewrite.constraint_solver import ConstraintMaker -from dataclasses import dataclass +from ..core.internal_cursors import Block, Cursor, Node, NodePath +from ..rewrite.constraint_solver import ( + TRUE_CONSTRAINT, + Constraint, + ConstraintMaker, + DisjointConstraint, + Expression, +) +from 
dataclasses import dataclass, field import numpy as np @@ -63,18 +77,31 @@ class Reference: is_config: bool +@dataclass +class Point: + index: str + + +@dataclass +class Slice: + lower_bound: str + upper_bound: str + + +WindowIndex = Union[Point, Slice] + + @dataclass class Dimension: size: str stride: str + window_idx: WindowIndex @dataclass class Tensor: name: Sym - offset: str dims: tuple[Dimension, ...] - resize_placeholder: Optional[int] ExoValue = Union[Constant, Reference, Tensor] @@ -84,33 +111,750 @@ class Tensor: INITIAL_DYNAMIC_SIZE = 16 +@dataclass +class SymbolicTensor: + name: Sym + dims: tuple[SymbolicWindowIndex, ...] + resize_placeholder: Optional[int] + + +class AliasingTracker: + def __init__(self, parent_state: "CoverageState"): + self.writes: dict[Sym, list[MemoryAccess]] = {} + self.reads: dict[Sym, list[MemoryAccess]] = {} + self.parent_state = parent_state + + def access_tensor( + self, + js_tensor: Tensor, + js_idxs: tuple[str, ...], + symbolic_idxs: tuple[Expression], + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + js_access = "+".join( + f"Math.imul({js_idx},{js_dim.stride})" + for js_idx, js_dim in zip(js_idxs, js_tensor.dims) + ) + access_sym = Sym("access") + mark_stmt = f"{repr(access_sym)}[{js_access}]=1;" + base_size = f"Math.imul({js_tensor.dims[0].size},{js_tensor.dims[0].stride})" + resize_placeholder = self.parent_state.symbolic_tensors[ + js_tensor.name + ].resize_placeholder + if resize_placeholder is None: + decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer({base_size});" + fillers = ( + IndexedFiller(access_placeholder, mark_stmt), + IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), + ) + else: + temp_sym = Sym("temp") + decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer(1,{{maxByteLength:{INITIAL_DYNAMIC_SIZE}}});" + resize_stmt = f"while({base_size}>{repr(access_sym)}.maxByteLength){{let {repr(temp_sym)}=new ArrayBuffer({repr(access_sym)}.byteLength,{{maxByteLength:2*{repr(access_sym)}.maxByteLength}});for(let i=0;i<{repr(access_sym)}.byteLength;i++){repr(temp_sym)}[i]={repr(access_sym)}[i];{repr(access_sym)}={repr(temp_sym)}}};{repr(access_sym)}.resize(Math.max({base_size},{repr(access_sym)}.byteLength));" + fillers = ( + IndexedFiller(access_placeholder, mark_stmt), + IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), + IndexedFiller(resize_placeholder, resize_stmt), + ) + + dest = self.writes if is_write else self.reads + if js_tensor.name not in dest: + dest[js_tensor.name] = [] + dest[js_tensor.name].append( + MemoryAccess(access_sym, cov_node, symbolic_idxs, fillers) + ) + + def access_scalar( + self, + name: Sym, + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + access_sym = Sym("access") + mark_stmt = f"{repr(access_sym)}=true;" + decl_stmt = f"let {repr(access_sym)}=false;" + + dest = self.writes if is_write else self.reads + if name not in dest: + dest[name] = [] + dest[name].append( + MemoryAccess( + access_sym, + cov_node, + (), + ( + IndexedFiller(access_placeholder, mark_stmt), + IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), + ), + ) + ) + + def make_aliasing_accesses(self) -> tuple[MemoryAccessPair, ...]: + aliasable_accesses: list[MemoryAccessPair] = [] + for sym, write_indices in self.writes.items(): + read_indices = self.reads[sym] if sym in self.reads else [] + for i, index1 in enumerate(write_indices): + for index2 in chain(write_indices[i + 1 :], read_indices): + 
aliasable_accesses.append(MemoryAccessPair(index1, index2)) + return tuple(aliasable_accesses) + + +class FailureTracker: + def __init__(self, scope: Block, parent_state: "CoverageState"): + self.scope = scope + self.in_scope = False + self.call_depth = 0 + self.failure_conditions: list[FailureCondition] = [] + self.parent_state = parent_state + + def enter_stmt(self, stmt_cursor: Node): + if stmt_cursor in self.scope: + self.in_scope = True + + def exit_stmt(self, stmt_cursor): + if stmt_cursor in self.scope: + self.in_scope = False + + def enter_proc_body(self): + self.call_depth += 1 + + def exit_proc_body(self): + self.call_depth -= 1 + + def add_assertion( + self, asserted_cond: DisjointConstraint, js_cond: str, placeholder: int + ): + if self.in_scope and self.call_depth == 1: + fail_sym = Sym("fail") + self.failure_conditions.append( + FailureCondition( + fail_sym, + asserted_cond.invert(), + self.parent_state.current_node, + ( + IndexedFiller( + self.parent_state.cov_placeholder, + f"let {repr(fail_sym)}=false;", + ), + IndexedFiller( + placeholder, f"if(!({js_cond})){repr(fail_sym)}=true;" + ), + ), + ) + ) + + def make_failures(self) -> tuple[FailureCondition, ...]: + return tuple(self.failure_conditions) + + +@dataclass +class SymbolicWindow: + name: Sym + index: tuple[SymbolicWindowIndex, ...] + + +@dataclass +class StageMemArgs: + window_expr: LoopIR.WindowExpr + scope: Block + + +class StageMemTracker: + def __init__(self, args: StageMemArgs, parent_state: "CoverageState"): + self.scope: Block = args.scope + self.buffer_sym = args.window_expr.name + self.staged_window: Optional[tuple[SymbolicTensor, Tensor]] = None + self.overlaps: list[StagingOverlap] = [] + self.parent_state: "CoverageState" = parent_state + + def enter_stmt(self, stmt_node: Node): + if stmt_node in self.scope: + js_tensor = self.parent_state.parent_transpiler._lookup_sym(self.buffer_sym) + assert isinstance(js_tensor, Tensor) + self.staged_window = ( + self.parent_state.symbolic_tensors[self.buffer_sym], + js_tensor, + ) + + def exit_stmt(self, stmt_cursor: Node): + if stmt_cursor in self.scope: + self.staged_window = None + + def access_tensor( + self, + js_tensor: Tensor, + js_idxs: tuple[str, ...], + symbolic_idxs: tuple[Expression], + access_placeholder: int, + access_cursor: Node, + cov_node: CoverageSkeletonNode, + _is_write: bool, + ): + if ( + self.staged_window is not None + and self.staged_window[1].name == js_tensor.name + ): + symbolic_staged_window, js_staged_window = self.staged_window + overlap_sym = Sym("overlap") + disjoint_sym = Sym("disjoint") + access_window = tuple(SymbolicPoint(idx) for idx in symbolic_idxs) + js_overlap_cond = "&&".join( + f"(({js_staged_slice.lower_bound})<=({js_idx}))&&(({js_staged_slice.upper_bound})>({js_idx}))" + for js_idx, js_staged_slice in zip( + js_idxs, + ( + dim.window_idx + for dim in js_staged_window.dims + if isinstance(dim.window_idx, Slice) + ), + ) + ) + self.overlaps.append( + StagingOverlap( + overlap_sym, + disjoint_sym, + symbolic_staged_window.dims, + access_window, + cov_node, + access_cursor.get_path(), + ( + IndexedFiller( + access_placeholder, + f"if({js_overlap_cond}){{{repr(overlap_sym)}=true}}else{{{repr(disjoint_sym)}=true}}", + ), + IndexedFiller( + self.parent_state.cov_placeholder, + f"let {repr(overlap_sym)}=false;let {repr(disjoint_sym)}=false;", + ), + ), + ) + ) + + def make_staging_overlaps(self) -> tuple[StagingOverlap, ...]: + return tuple(self.overlaps) + + +@dataclass +class ParallelScope: + iter_sym: Sym + 
loop_entrance_placeholder: int + writes: dict[Sym, list[ParallelAccess]] = field(default_factory=lambda: {}) + reads: dict[Sym, list[ParallelAccess]] = field(default_factory=lambda: {}) + access_set_syms: dict[Sym, Sym] = field(default_factory=lambda: {}) + + +class ParallelAccessTracker: + def __init__(self, parent_state: "CoverageState"): + self.parallel_scopes: list[ParallelScope] = [] + self.coverage_sym: Sym = Sym("par") + self.pairs: list[ParallelAccessPair] = [] + self.parent_state: "CoverageState" = parent_state + + def enter_loop(self, loop: LoopIR.For, loop_entrance_placeholder: int): + if isinstance(loop.loop_mode, LoopIR.Par): + self.parallel_scopes.append( + ParallelScope(loop.iter, loop_entrance_placeholder) + ) + + def exit_loop_body(self, loop: LoopIR.For): + if isinstance(loop.loop_mode, LoopIR.Par): + scope = self.parallel_scopes.pop() + loop_tail_placeholder = ( + self.parent_state.parent_transpiler._make_placeholder() + ) + for sym, sym_writes in scope.writes.items(): + for sym_write_idx, sym_write in enumerate(sym_writes): + for other_access in chain( + sym_writes[sym_write_idx:], + (scope.reads[sym] if sym in scope.reads else []), + ): + self.pairs.append( + ParallelAccessPair( + self.coverage_sym, + scope.iter_sym, + sym_write, + other_access, + ( + IndexedFiller( + loop_tail_placeholder, + f"{repr(scope.access_set_syms[sym])}_pw={repr(scope.access_set_syms[sym])}_pw.union({repr(scope.access_set_syms[sym])}_cw);", + ), + IndexedFiller( + loop_tail_placeholder, + f"{repr(scope.access_set_syms[sym])}_pr={repr(scope.access_set_syms[sym])}_pr.union({repr(scope.access_set_syms[sym])}_cr);", + ), + ), + ) + ) + + def access_tensor( + self, + js_tensor: Tensor, + js_idxs: tuple[str, ...], + symbolic_idxs: tuple[Expression], + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + js_access = "+".join( + f"Math.imul({js_idx},{js_dim.stride})" + for js_idx, js_dim in zip(js_idxs, js_tensor.dims) + ) + for parallel_scope in self.parallel_scopes: + dest = parallel_scope.writes if is_write else parallel_scope.reads + if js_tensor.name not in dest: + dest[js_tensor.name] = [] + if js_tensor.name not in parallel_scope.access_set_syms: + parallel_scope.access_set_syms[js_tensor.name] = Sym("access_set") + access_set_sym = parallel_scope.access_set_syms[js_tensor.name] + dest[js_tensor.name].append( + ParallelAccess( + cov_node, + symbolic_idxs, + ( + IndexedFiller( + parallel_scope.loop_entrance_placeholder, + "".join( + ( + f"let {repr(access_set_sym)}_pr=new Set();", + f"let {repr(access_set_sym)}_pw=new Set();", + f"let {repr(access_set_sym)}_cr=new Set();", + f"let {repr(access_set_sym)}_cw=new Set();", + f"let {repr(self.coverage_sym)}=false;", + ) + ), + ), + IndexedFiller( + access_placeholder, + "".join( + ( + f"{repr(access_set_sym)}{'_cw' if is_write else '_cr'}.add({js_access});", + f"if({repr(access_set_sym)}_pw.has({js_access})){{{repr(self.coverage_sym)}=true}}", + *( + ( + f"if({repr(access_set_sym)}_pr.has({js_access})){{{repr(self.coverage_sym)}=true}}", + ) + if is_write + else () + ), + ), + ), + ), + ), + ) + ) + + def access_scalar( + self, + name: Sym, + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + self.access_tensor( + Tensor(name, (Dimension("1", "0", Slice("0", "1")),)), + ("0",), + tuple(), + access_placeholder, + _access_cursor, + cov_node, + is_write, + ) + + def make_parallel_access_pairs(self) -> tuple[ParallelAccessPair, ...]: + return 
tuple(self.pairs) + + @dataclass class CoverageArgs: cm: ConstraintMaker + failure_scope: Optional[Block] = None + stage_mem_args: Optional[StageMemArgs] = None class CoverageState: - def __init__(self, args: CoverageArgs, cov_placeholder: int): + def __init__(self, args: CoverageArgs, parent_transpiler: "Transpiler"): self.cm: ConstraintMaker = args.cm + self.parent_transpiler: Transpiler = parent_transpiler + self.cov_placeholder: int = parent_transpiler._make_placeholder() self.root: CoverageSkeletonNode = CoverageSkeletonNode(None, None, ()) - self.buffer_writes: dict[Sym, list[MemoryAccess]] = {} - self.buffer_reads: dict[Sym, list[MemoryAccess]] = {} + self.current_node: CoverageSkeletonNode = self.root + self.symbolic_tensors: dict[Sym, SymbolicTensor] = {} + self.scalar_symbols: dict[Sym, Sym] = {} + self.ctxt_symbols: dict[tuple[Config, str], Sym] = {} + self.aliasing_tracker = AliasingTracker(self) + self.failure_tracker: Optional[FailureTracker] = ( + None + if args.failure_scope is None + else FailureTracker(args.failure_scope, self) + ) + self.stage_mem_tracker: Optional[StageMemTracker] = ( + None + if args.stage_mem_args is None + else StageMemTracker(args.stage_mem_args, self) + ) + self.parallel_access_tracker = ParallelAccessTracker(self) self.free_vars: list[Sym] = [] - self.cov_placeholder = cov_placeholder - def make_skeleton(self) -> CoverageSkeleton: - aliasable_accesses: list[MemoryAccessPair] = [] - for sym, write_indices in self.buffer_writes.items(): - read_indices = self.buffer_reads[sym] if sym in self.buffer_reads else [] - for i, index1 in enumerate(write_indices): - for index2 in chain(write_indices[i + 1 :], read_indices): - aliasable_accesses.append(MemoryAccessPair(index1, index2)) + def enter_loop( + self, + stmt: LoopIR.For, + lo_js: str, + hi_js: str, + transpile_loop_body: Callable[[], None], + ): + body_sym = Sym("body") + skip_sym = Sym("skip") + loop_entrance_placeholder = self.parent_transpiler._make_placeholder() + body_constraint = ( + self.cm.make_constraint_from_inequality(stmt.lo, stmt.iter, "<=") + .lift_to_disjoint_constraint() + .intersect( + self.cm.make_constraint_from_inequality( + stmt.iter, stmt.hi, "<" + ).lift_to_disjoint_constraint() + ) + ) + skip_constraint = self.cm.make_constraint_from_inequality( + stmt.lo, stmt.hi, ">=" + ).lift_to_disjoint_constraint() + parent_node = self.current_node + body_child = CoverageSkeletonNode( + body_sym, + (parent_node, body_constraint), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(body_sym)}=false;", + ), + IndexedFiller( + loop_entrance_placeholder, + f"{repr(body_sym)}||=({lo_js}<{hi_js});", + ), + ), + ) + skip_child = CoverageSkeletonNode( + skip_sym, + (parent_node, skip_constraint), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(skip_sym)}=false;", + ), + IndexedFiller( + loop_entrance_placeholder, + f"{repr(skip_sym)}||=({lo_js}>={hi_js});", + ), + ), + ) + self.current_node.branches.append( + CoverageSkeletonBranch(body_child, skip_child) + ) + self.parallel_access_tracker.enter_loop(stmt, loop_entrance_placeholder) + self.free_vars.append(stmt.iter) + self.current_node = body_child + transpile_loop_body() + self.current_node = parent_node + self.parallel_access_tracker.exit_loop_body(stmt) + + def enter_if( + self, + stmt: LoopIR.If, + transpile_if_body: Callable[[], None], + transpile_else_body: Callable[[], None], + ): + parent_node = self.current_node + true_sym = Sym("true_case") + false_sym = Sym("false_case") + true_placeholder = 
self.parent_transpiler._make_placeholder() + cond_constraint = self.cm.make_constraint(stmt.cond) + true_node = CoverageSkeletonNode( + true_sym, + (parent_node, cond_constraint), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(true_sym)}=false;", + ), + IndexedFiller(true_placeholder, f"{repr(true_sym)}=true;"), + ), + ) + self.current_node = true_node + transpile_if_body() + false_placeholder = self.parent_transpiler._make_placeholder() + false_node = CoverageSkeletonNode( + false_sym, + (parent_node, cond_constraint.invert()), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(false_sym)}=false;", + ), + IndexedFiller(false_placeholder, f"{repr(false_sym)}=true;"), + ), + ) + self.current_node = false_node + transpile_else_body() + self.current_node = parent_node + new_branch = CoverageSkeletonBranch(true_node, false_node) + self.current_node.branches.append(new_branch) + + def enter_stmt(self, stmt_cursor: Node): + if self.failure_tracker is not None: + self.failure_tracker.enter_stmt(stmt_cursor) + if self.stage_mem_tracker is not None: + self.stage_mem_tracker.enter_stmt(stmt_cursor) + + def exit_stmt(self, stmt_cursor: Node): + if self.failure_tracker is not None: + self.failure_tracker.exit_stmt(stmt_cursor) + if self.stage_mem_tracker is not None: + self.stage_mem_tracker.exit_stmt(stmt_cursor) + + def enter_proc_body(self): + if self.failure_tracker is not None: + self.failure_tracker.enter_proc_body() + + def exit_proc_body(self): + if self.failure_tracker is not None: + self.failure_tracker.exit_proc_body() + + def assert_shape_matches( + self, tensor_sym: Sym, shape: list[LoopIR.expr], shape_matches_js: str + ): + match_cond = TRUE_CONSTRAINT + for tensor_dim, shape_dim in zip( + ( + dim + for dim in self.symbolic_tensors[tensor_sym].dims + if isinstance(dim, SymbolicSlice) + ), + shape, + ): + match_cond = match_cond.intersect( + Constraint( + self.cm.make_expression(shape_dim) + .negate() + .add(tensor_dim.upper_bound) + .add(tensor_dim.lower_bound.negate()), + False, + ).lift_to_disjoint_constraint(), + ) + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + match_cond, shape_matches_js, self.parent_transpiler._make_placeholder() + ) + + def assert_predicate(self, pred: LoopIR.expr, js_pred: str): + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + self.cm.make_constraint(pred), + js_pred, + self.parent_transpiler._make_placeholder(), + ) + def make_tensor(self, sym: Sym, dims: list[LoopIR.expr], nonnegative_dims_js: str): + symbolic_dims = tuple(self.cm.make_expression(dim) for dim in dims) + nonnegative_constraint = TRUE_CONSTRAINT + for symbolic_dim in symbolic_dims: + nonnegative_constraint = nonnegative_constraint.intersect( + Constraint(symbolic_dim, True).lift_to_disjoint_constraint() + ) + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + nonnegative_constraint, + nonnegative_dims_js, + self.parent_transpiler._make_placeholder(), + ) + self.symbolic_tensors[sym] = SymbolicTensor( + sym, + tuple( + SymbolicSlice(Expression.from_constant(0), symbolic_dim) + for symbolic_dim in symbolic_dims + ), + self.parent_transpiler._make_placeholder(), + ) + + def make_scalar(self, sym: Sym): + self.scalar_symbols[sym] = sym + + def assign_tensor(self, arg_name: Sym, original_name: Sym): + self.symbolic_tensors[arg_name] = self.symbolic_tensors[original_name] + + def assign_scalar(self, arg_name: Sym, original_name: Sym): + self.scalar_symbols[arg_name] = 
self.scalar_symbols[original_name] + + def assign_scalar_from_context(self, scalar_sym: Sym, config: Config, field: str): + config_key = (config, field) + if config_key not in self.ctxt_symbols: + self.ctxt_symbols[config_key] = Sym("ctxt") + self.scalar_symbols[scalar_sym] = self.ctxt_symbols[config_key] + + def assign_window( + self, sym: Sym, window_expr: LoopIR.WindowExpr, in_bounds_js: str + ): + base_tensor = self.symbolic_tensors[window_expr.name] + in_bounds_constraint = TRUE_CONSTRAINT + window_dims = [] + window_idx_iter = iter(window_expr.idx) + for dim in base_tensor.dims: + if isinstance(dim, SymbolicPoint): + window_dims.append(dim) + else: + idx = next(window_idx_iter) + if isinstance(idx, LoopIR.Interval): + new_dim = SymbolicSlice( + self.cm.make_expression(idx.lo).add(dim.lower_bound), + self.cm.make_expression(idx.hi).add(dim.lower_bound), + ) + in_bounds_constraint = in_bounds_constraint.intersect( + Constraint( + new_dim.lower_bound.add(dim.lower_bound.negate()), True + ).lift_to_disjoint_constraint() + ).intersect( + Constraint( + dim.upper_bound.add(new_dim.upper_bound.negate()), True + ).lift_to_disjoint_constraint() + ) + window_dims.append(new_dim) + else: + new_dim = SymbolicPoint( + self.cm.make_expression(idx.pt).add(dim.lower_bound) + ) + in_bounds_constraint = in_bounds_constraint.intersect( + Constraint( + new_dim.index.add(dim.lower_bound.negate()), True + ).lift_to_disjoint_constraint() + ).intersect( + Constraint( + dim.upper_bound.add(new_dim.index.negate()).add( + Expression.from_constant(-1) + ), + True, + ).lift_to_disjoint_constraint() + ) + window_dims.append(new_dim) + + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + in_bounds_constraint, + in_bounds_js, + self.parent_transpiler._make_placeholder(), + ) + self.symbolic_tensors[sym] = SymbolicTensor( + base_tensor.name, + tuple(window_dims), + base_tensor.resize_placeholder, + ) + + def access_tensor( + self, + access_cursor: Node, + js_idxs: tuple[str, ...], + is_write: bool, + in_bounds_js: str, + ): + js_tensor = self.parent_transpiler._lookup_sym(access_cursor._node.name) + assert isinstance(js_tensor, Tensor) + symbolic_tensor = self.symbolic_tensors[access_cursor._node.name] + idx_expr_iter = iter(access_cursor._node.idx) + symbolic_idxs = [] + in_bounds_constraint = TRUE_CONSTRAINT + for dim in symbolic_tensor.dims: + if isinstance(dim, SymbolicSlice): + idx = self.cm.make_expression(next(idx_expr_iter)).add(dim.lower_bound) + in_bounds_constraint = in_bounds_constraint.intersect( + Constraint( + idx.negate().add(dim.upper_bound), True + ).lift_to_disjoint_constraint() + ) + symbolic_idxs.append(idx) + else: + symbolic_idxs.append(dim.index) + access_placeholder = self.parent_transpiler._make_placeholder() + access_args = ( + js_tensor, + js_idxs, + tuple(symbolic_idxs), + access_placeholder, + access_cursor, + self.current_node, + is_write, + ) + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + in_bounds_constraint, + in_bounds_js, + self.parent_transpiler._make_placeholder(), + ) + self.aliasing_tracker.access_tensor(*access_args) + if self.stage_mem_tracker is not None: + self.stage_mem_tracker.access_tensor(*access_args) + self.parallel_access_tracker.access_tensor(*access_args) + + def access_scalar(self, access_cursor: Node, is_write: bool): + access_placeholder = self.parent_transpiler._make_placeholder() + scalar_sym = self.scalar_symbols[access_cursor._node.name] + access_args = ( + scalar_sym, + access_placeholder, + 
access_cursor, + self.current_node, + is_write, + ) + self.aliasing_tracker.access_scalar(*access_args) + self.parallel_access_tracker.access_scalar(*access_args) + + def access_context(self, access_cursor: Node, is_write: bool): + access_placeholder = self.parent_transpiler._make_placeholder() + config_key = (access_cursor._node.config, access_cursor._node.field) + if config_key not in self.ctxt_symbols: + self.ctxt_symbols[config_key] = Sym("ctxt") + ctxt_sym = self.ctxt_symbols[config_key] + access_args = ( + ctxt_sym, + access_placeholder, + access_cursor, + self.current_node, + is_write, + ) + self.aliasing_tracker.access_scalar(*access_args) + self.parallel_access_tracker.access_scalar(*access_args) + + def make_skeleton(self) -> CoverageSkeleton: return CoverageSkeleton( - (self.root,), tuple(aliasable_accesses), frozenset(self.free_vars) + (self.root,), + self.aliasing_tracker.make_aliasing_accesses(), + ( + () + if self.failure_tracker is None + else self.failure_tracker.make_failures() + ), + ( + () + if self.stage_mem_tracker is None + else self.stage_mem_tracker.make_staging_overlaps() + ), + self.parallel_access_tracker.make_parallel_access_pairs(), + frozenset(self.free_vars), ) +def get_shape_cursor(type_cursor: Node) -> Block: + return ( + type_cursor._child_node("as_tensor") + if isinstance(type_cursor._node, LoopIR.WindowType) + else type_cursor + )._child_block("hi") + + class Transpiler: def __init__(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs] = None): self._name_lookup: dict[Sym, ExoValue] = {} @@ -119,7 +863,7 @@ def __init__(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs] = No self._buffer_args: list[Sym] = [] self._coverage_state: Optional[CoverageState] = None self._skeleton: Optional[CoverageSkeleton] = None - self.proc = proc # debug + self._proc = proc self._transpile_proc(proc, coverage_args) def get_javascript_template(self) -> Template: @@ -143,50 +887,63 @@ def get_size_param_name(self, tensor_name: Sym, dim_idx: int): def get_coverage_skeleton(self) -> Optional[CoverageSkeleton]: return self._skeleton + def get_proc(self) -> LoopIR.proc: + return self._proc + def _assert_at_runtime(self, expr: str): self._js_lines.append(f"if(!{expr})return [1,{CONTEXT_OBJECT_NAME},{{}}];") + # for CoverageState def _make_placeholder(self) -> int: placeholder_index = len(self._js_lines) self._js_lines.append("") return placeholder_index + # for CoverageState + def _lookup_sym(self, sym: Sym) -> ExoValue: + return self._name_lookup[sym] + def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs]): + self._buffer_args = [arg.name for arg in proc.args if arg.type.is_numeric()] + self._js_lines.append( + f'(({",".join(repr(arg) for arg in self._buffer_args)})=>{{' + ) + ctxt_placeholder = self._make_placeholder() + if coverage_args is not None: + self._coverage_state = CoverageState(coverage_args, self) arg_values = [] for arg in proc.args: if arg.type.is_numeric(): - self._buffer_args.append(arg.name) if arg.type.is_tensor_or_window(): value = Tensor( arg.name, - "0", tuple( Dimension( - f"${self.get_size_param_name(arg.name, dim_idx)}", - f"${self.get_stride_param_name(arg.name, dim_idx)}", + f"${size}", + f"${stride}", + Slice("0", f"${size}"), + ) + for size, stride in map( + lambda dim_idx: ( + self.get_size_param_name(arg.name, dim_idx), + self.get_stride_param_name(arg.name, dim_idx), + ), + range(len(arg.type.shape())), ) - for dim_idx in range(len(arg.type.shape())) ), - None, ) + if self._coverage_state is not 
None: + self._coverage_state.make_tensor( + arg.name, arg.type.shape(), "true" + ) else: value = Reference(repr(arg.name), False) + if self._coverage_state is not None: + self._coverage_state.make_scalar(arg.name) else: value = Constant(f"${repr(arg.name)}") arg_values.append(value) - self._js_lines.append( - f'(({",".join(repr(arg) for arg in self._buffer_args)})=>{{' - ) - ctxt_placeholder = self._make_placeholder() - if coverage_args is not None: - self._coverage_state = CoverageState( - coverage_args, self._make_placeholder() - ) - self._call_proc( - proc, - tuple(arg_values), - None if self._coverage_state is None else self._coverage_state.root, - ) + self._call_proc(Cursor.create(proc), tuple(arg_values), True) coverage_object = "" if self._coverage_state is not None: skeleton = self._coverage_state.make_skeleton() @@ -202,121 +959,99 @@ def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArg self._js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}" def _call_proc( - self, - proc: LoopIR.proc, - arg_values: tuple[ExoValue, ...], - coverage_node: Optional[CoverageSkeletonNode], + self, proc_cursor: Node, arg_values: tuple[ExoValue, ...], top_level: bool ): - for arg, arg_value in zip(proc.args, arg_values): + for arg_cursor, arg_value in zip(proc_cursor._child_block("args"), arg_values): + arg = arg_cursor._node self._name_lookup[arg.name] = arg_value if arg.type.is_tensor_or_window(): assert isinstance(arg_value, Tensor) - for arg_dim, arg_dim_expr in zip(arg_value.dims, arg.type.shape()): - self._assert_at_runtime( - f"({arg_dim.size}=={self._transpile_expr(arg_dim_expr, None)})", + shape_matches_js = "&&".join( + f"((({arg_dim_slice.upper_bound})-({arg_dim_slice.lower_bound}))==({self._transpile_expr(arg_dim_cursor)}))" + for arg_dim_slice, arg_dim_cursor in zip( + ( + dim.window_idx + for dim in arg_value.dims + if isinstance(dim.window_idx, Slice) + ), + get_shape_cursor(arg_cursor._child_node("type")), + ) + ) + if self._coverage_state is not None and not top_level: + self._coverage_state.assert_shape_matches( + arg.name, arg.type.shape(), shape_matches_js ) + self._assert_at_runtime(shape_matches_js) - for pred in proc.preds: - self._assert_at_runtime(self._transpile_expr(pred, None)) + for pred_cursor in proc_cursor._child_block("preds"): + js_pred = self._transpile_expr(pred_cursor) + if self._coverage_state is not None and not top_level: + self._coverage_state.assert_predicate(pred_cursor._node, js_pred) + self._assert_at_runtime(js_pred) - for stmt in proc.body: - self._transpile_stmt(stmt, coverage_node) + if self._coverage_state is not None: + self._coverage_state.enter_proc_body() + for stmt_cursor in proc_cursor._child_block("body"): + self._transpile_stmt(stmt_cursor) + if self._coverage_state is not None: + self._coverage_state.exit_proc_body() - def _get_index_expr( + def _get_index_exprs( self, buf: Tensor, - idxs: Iterable[LoopIR.expr], - coverage_node: Optional[CoverageSkeletonNode], - ): - idx_exprs = tuple(self._transpile_expr(idx, coverage_node) for idx in idxs) - for idx_expr, dim in zip(idx_exprs, buf.dims): - self._assert_at_runtime(f"({idx_expr}<{dim.size}&&{idx_expr}>=0)") - relative_idx = reduce( - lambda dim1, dim2: f"{dim1}+{dim2}", - ( - f"Math.imul({idx_expr},{dim.stride})" - for idx_expr, dim in zip(idx_exprs, buf.dims) - ), - ) - return f"{relative_idx}+{buf.offset}" + idxs: Block, + ) -> tuple[str, ...]: + idx_expr_iter = iter(self._transpile_expr(idx) for idx in idxs) + idx_parts = [] + for dim in buf.dims: 
+ if isinstance(dim.window_idx, Slice): + idx_expr = next(idx_expr_iter) + idx_parts.append(f"(({idx_expr})+({dim.window_idx.lower_bound}))") + else: + idx_parts.append(f"({dim.window_idx.index})") + return tuple(idx_parts) - def _make_scalar_access_fillers(self, access_sym: Sym) -> tuple[IndexedFiller, ...]: - assert self._coverage_state is not None - mark_placeholder = self._make_placeholder() - mark_stmt = f"{repr(access_sym)}=true;" - decl_stmt = f"let {repr(access_sym)}=false;" - return ( - IndexedFiller(mark_placeholder, mark_stmt), - IndexedFiller(self._coverage_state.cov_placeholder, decl_stmt), + def _get_in_bounds_condition( + self, index_exprs: tuple[str, ...], buf: Tensor + ) -> str: + return "&&".join( + f"(({index_expr})>=({dim.window_idx.lower_bound})&&({index_expr})<({dim.window_idx.upper_bound}))" + for index_expr, dim in zip(index_exprs, buf.dims) + if isinstance(dim.window_idx, Slice) ) - def _make_tensor_access_fillers( - self, access_sym: Sym, buffer: Tensor, idx: Iterable[LoopIR.expr] - ) -> tuple[IndexedFiller, ...]: - assert self._coverage_state is not None - mark_stmt = f"{repr(access_sym)}[{self._get_index_expr(buffer, idx, None)}]=1;" - mark_placeholder = self._make_placeholder() - base_buffer = self._name_lookup[buffer.name] - assert isinstance(base_buffer, Tensor) - base_dims = base_buffer.dims - base_size = reduce( - lambda dim1, dim2: f"Math.imul({dim1},{dim2})", - (dim.size for dim in base_dims), - ) - if buffer.resize_placeholder is None: - decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer({base_size});" - return ( - IndexedFiller(mark_placeholder, mark_stmt), - IndexedFiller(self._coverage_state.cov_placeholder, decl_stmt), - ) - else: - temp_sym = Sym("temp") - decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer(1,{{maxByteLength:{INITIAL_DYNAMIC_SIZE}}});" - resize_stmt = f"while({base_size}>{repr(access_sym)}.maxByteLength){{let {repr(temp_sym)}=new ArrayBuffer({repr(access_sym)}.byteLength,{{maxByteLength:2*{repr(access_sym)}.maxByteLength}});for(let i=0;i<{repr(access_sym)}.byteLength;i++){repr(temp_sym)}[i]={repr(access_sym)}[i];{repr(access_sym)}={repr(temp_sym)}}};{repr(access_sym)}.resize(Math.max({base_size},{repr(access_sym)}.byteLength));" - return ( - IndexedFiller(mark_placeholder, mark_stmt), - IndexedFiller(self._coverage_state.cov_placeholder, decl_stmt), - IndexedFiller(buffer.resize_placeholder, resize_stmt), - ) - def _transpile_stmt( - self, stmt: LoopIR.stmt, coverage_node: Optional[CoverageSkeletonNode] + self, + stmt_cursor: Node, ): + if self._coverage_state is not None: + self._coverage_state.enter_stmt(stmt_cursor) + stmt = stmt_cursor._node if isinstance(stmt, (LoopIR.Assign, LoopIR.Reduce)): lhs_buffer = self._name_lookup[stmt.name] - - if self._coverage_state is not None and coverage_node is not None: - write_sym = Sym("write") - if stmt.name not in self._coverage_state.buffer_writes: - self._coverage_state.buffer_writes[ - lhs_buffer.name if isinstance(lhs_buffer, Tensor) else stmt.name - ] = [] - self._coverage_state.buffer_writes[ - lhs_buffer.name if isinstance(lhs_buffer, Tensor) else stmt.name - ].append( - MemoryAccess( - write_sym, - coverage_node, - tuple( - self._coverage_state.cm.make_expression(idx) - for idx in stmt.idx - ), - ( - self._make_tensor_access_fillers( - write_sym, lhs_buffer, stmt.idx - ) - if isinstance(lhs_buffer, Tensor) - else self._make_scalar_access_fillers(write_sym) - ), - ) - ) - rhs = self._transpile_expr(stmt.rhs, coverage_node) + rhs = 
self._transpile_expr(stmt_cursor._child_node("rhs")) if isinstance(lhs_buffer, Reference): lhs = ( lhs_buffer.name if lhs_buffer.is_config else f"{lhs_buffer.name}[0]" ) + if self._coverage_state is not None: + self._coverage_state.access_scalar(stmt_cursor, True) elif isinstance(lhs_buffer, Tensor): - lhs = f"{repr(lhs_buffer.name)}[{self._get_index_expr(lhs_buffer, stmt.idx, coverage_node)}]" + index_exprs = self._get_index_exprs( + lhs_buffer, + stmt_cursor._child_block("idx"), + ) + index = f"+".join( + f"Math.imul({dim.stride},{index_expr})" + for dim, index_expr in zip(lhs_buffer.dims, index_exprs) + ) + lhs = f"{repr(lhs_buffer.name)}[{index}]" + in_bounds_js = self._get_in_bounds_condition(index_exprs, lhs_buffer) + if self._coverage_state is not None: + self._coverage_state.access_tensor( + stmt_cursor, index_exprs, True, in_bounds_js + ) + self._assert_at_runtime(in_bounds_js) else: assert False if isinstance(stmt, LoopIR.Assign): @@ -325,216 +1060,213 @@ def _transpile_stmt( self._js_lines.append(f"{lhs}+={rhs};") elif isinstance(stmt, LoopIR.WriteConfig): config_name = self.get_config_param_name(stmt.config, stmt.field) - rhs = self._transpile_expr(stmt.rhs, coverage_node) + rhs = self._transpile_expr(stmt_cursor._child_node("rhs")) self._js_lines.append(f'{CONTEXT_OBJECT_NAME}["{config_name}"]={rhs};') self._configs.add((stmt.config, stmt.field)) + if self._coverage_state is not None: + self._coverage_state.access_context(stmt_cursor, True) elif isinstance(stmt, LoopIR.Pass): pass elif isinstance(stmt, LoopIR.If): - cond = self._transpile_expr(stmt.cond, coverage_node) + cond = self._transpile_expr(stmt_cursor._child_node("cond")) self._js_lines.append(f"if({cond}){{") - if self._coverage_state is not None and coverage_node is not None: - true_sym = Sym("true_case") - false_sym = Sym("false_case") - true_placeholder = self._make_placeholder() - cond_constraint = self._coverage_state.cm.make_constraint(stmt.cond) - true_node = CoverageSkeletonNode( - true_sym, - (coverage_node, cond_constraint), - ( - IndexedFiller( - self._coverage_state.cov_placeholder, - f"let {repr(true_sym)}=false;", - ), - IndexedFiller(true_placeholder, f"{repr(true_sym)}=true;"), - ), - ) - for body_stmt in stmt.body: - self._transpile_stmt(body_stmt, true_node) + def transpile_if_body(): + for body_cursor in stmt_cursor._child_block("body"): + self._transpile_stmt(body_cursor) self._js_lines.append("}else{") - false_placeholder = self._make_placeholder() - false_node = CoverageSkeletonNode( - false_sym, - (coverage_node, cond_constraint.invert()), - ( - IndexedFiller( - self._coverage_state.cov_placeholder, - f"let {repr(false_sym)}=false;", - ), - IndexedFiller(false_placeholder, f"{repr(false_sym)}=true;"), - ), + + def transpile_else_body(): + for else_cursor in stmt_cursor._child_block("orelse"): + self._transpile_stmt(else_cursor) + self._js_lines.append("}") + + if self._coverage_state is not None: + self._coverage_state.enter_if( + stmt, transpile_if_body, transpile_else_body ) - for else_stmt in stmt.orelse: - self._transpile_stmt(else_stmt, false_node) - new_branch = CoverageSkeletonBranch(true_node, false_node) - coverage_node.branches.append(new_branch) else: - for body_stmt in stmt.body: - self._transpile_stmt(body_stmt, None) - self._js_lines.append("}else{") - for else_stmt in stmt.orelse: - self._transpile_stmt(else_stmt, None) - self._js_lines.append("}") + transpile_if_body() + transpile_else_body() elif isinstance(stmt, LoopIR.For): iter_name = repr(stmt.iter) - iter_lo = 
self._transpile_expr(stmt.lo, coverage_node) - iter_hi = self._transpile_expr(stmt.hi, coverage_node) + iter_lo = self._transpile_expr(stmt_cursor._child_node("lo")) + iter_hi = self._transpile_expr(stmt_cursor._child_node("hi")) self._name_lookup[stmt.iter] = Constant(iter_name) - body_child, skip_child = None, None - if self._coverage_state is not None and coverage_node is not None: - body_sym = Sym("body") - skip_sym = Sym("skip") - loop_placeholder = self._make_placeholder() - body_constraint = ( - self._coverage_state.cm.make_constraint_from_inequality( - stmt.lo, stmt.iter, "<=" - ) - .lift_to_disjoint_constraint() - .intersect( - self._coverage_state.cm.make_constraint_from_inequality( - stmt.iter, stmt.hi, "<" - ).lift_to_disjoint_constraint() - ) - ) - skip_constraint = ( - self._coverage_state.cm.make_constraint_from_inequality( - stmt.lo, stmt.hi, ">=" - ).lift_to_disjoint_constraint() - ) - body_child = CoverageSkeletonNode( - body_sym, - (coverage_node, body_constraint), - ( - IndexedFiller( - self._coverage_state.cov_placeholder, - f"let {repr(body_sym)}=false;", - ), - IndexedFiller( - loop_placeholder, - f"{repr(body_sym)}||=({iter_lo}<{iter_hi});", - ), - ), + def transpile_loop_body(): + self._js_lines.append( + f"for(let {iter_name}={iter_lo};{iter_name}<{iter_hi};{iter_name}++){{" ) - skip_child = CoverageSkeletonNode( - skip_sym, - (coverage_node, skip_constraint), - ( - IndexedFiller( - self._coverage_state.cov_placeholder, - f"let {repr(skip_sym)}=false;", - ), - IndexedFiller( - loop_placeholder, - f"{repr(skip_sym)}||=({iter_lo}>={iter_hi});", - ), - ), + for body_cursor in stmt_cursor._child_block("body"): + self._transpile_stmt(body_cursor) + + if self._coverage_state is not None: + self._coverage_state.enter_loop( + stmt, iter_lo, iter_hi, transpile_loop_body ) - self._coverage_state.free_vars.append(stmt.iter) - new_loop = CoverageSkeletonBranch(body_child, skip_child) - coverage_node.branches.append(new_loop) + else: + transpile_loop_body() - self._js_lines.append( - f"for(let {iter_name}={iter_lo};{iter_name}<{iter_hi};{iter_name}++){{" - ) - for body_stmt in stmt.body: - self._transpile_stmt(body_stmt, body_child) self._js_lines.append("}") elif isinstance(stmt, LoopIR.Alloc): assert stmt.type.is_numeric() if stmt.type.is_tensor_or_window(): - tensor_name = repr(stmt.name) + tensor_name: Sym = stmt.name buffer_type = lookup_loopir_type( stmt.type.basetype() ).javascript_array_type - dim_exprs = tuple( - self._transpile_expr(dim, coverage_node) - for dim in stmt.type.shape() - ) - for dim_expr in dim_exprs: - self._assert_at_runtime(f"({dim_expr}>=0)") - buffer_size = reduce( - lambda dim1, dim2: f"Math.imul({dim1},{dim2})", dim_exprs + shape_cursor = get_shape_cursor(stmt_cursor._child_node("type")) + dims = len(shape_cursor) + self._js_lines.append( + "".join( + f"let {self.get_size_param_name(tensor_name, dim_idx)}={self._transpile_expr(dim_cursor)};" + for dim_idx, dim_cursor in enumerate(shape_cursor) + ) ) self._js_lines.append( - f"let {tensor_name}=new {buffer_type}({buffer_size});" + "".join( + f'let {self.get_stride_param_name(tensor_name, dim_idx)}={f"1" if dim_idx + 1 == dims else f"Math.imul({self.get_stride_param_name(tensor_name, dim_idx + 1)},{self.get_size_param_name(tensor_name,dim_idx + 1)})"};' + for dim_idx in reversed(range(dims)) + ) + ) + nonnegative_dims_js = "&&".join( + f"({self.get_size_param_name(tensor_name, dim_idx)}>=0)" + for dim_idx in range(dims) ) - resize_placeholder = len(self._js_lines) - self._js_lines.append("") - 
dimensions: list[Dimension] = [] - for dim_idx, dim_expr in enumerate(dim_exprs): - self._assert_at_runtime(f"({dim_expr}>=0)") - stride_expr = reduce( - lambda dim1, dim2: f"Math.imul({dim1},{dim2})", - dim_exprs[dim_idx + 1 :], - "1", + if self._coverage_state is not None: + self._coverage_state.make_tensor( + tensor_name, stmt.type.shape(), nonnegative_dims_js ) - dimensions.append(Dimension(dim_expr, stride_expr)) + self._assert_at_runtime(nonnegative_dims_js) + self._js_lines.append( + f"let {repr(tensor_name)}=new {buffer_type}(Math.imul({self.get_size_param_name(tensor_name, 0)},{self.get_stride_param_name(tensor_name, 0)}));" + ) self._name_lookup[stmt.name] = Tensor( - stmt.name, "0", tuple(dimensions), resize_placeholder + stmt.name, + tuple( + Dimension(size, stride, Slice("0", size)) + for size, stride in map( + lambda dim_idx: ( + self.get_size_param_name(tensor_name, dim_idx), + self.get_stride_param_name(tensor_name, dim_idx), + ), + range(dims), + ) + ), ) else: ref_name = repr(stmt.name) buffer_type = lookup_loopir_type(stmt.type).javascript_array_type + if self._coverage_state is not None: + self._coverage_state.make_scalar(stmt.name) self._js_lines.append(f"let {ref_name}=new {buffer_type}(1);") self._name_lookup[stmt.name] = Reference(ref_name, False) elif isinstance(stmt, LoopIR.Free): pass elif isinstance(stmt, LoopIR.Call): self._call_proc( - stmt.f, + stmt_cursor._child_node("f"), tuple( ( - self._transpile_buffer_arg(arg_expr, coverage_node) - if arg_expr.type.is_numeric() - else Constant(self._transpile_expr(arg_expr, coverage_node)) + self._transpile_buffer_arg( + arg_val_cursor, arg_name_cursor._node.name + ) + if arg_val_cursor._node.type.is_numeric() + else Constant(self._transpile_expr(arg_val_cursor)) + ) + for arg_val_cursor, arg_name_cursor in zip( + stmt_cursor._child_block("args"), + stmt_cursor._child_node("f")._child_block("args"), ) - for arg_expr in stmt.args ), - coverage_node, + False, ) elif isinstance(stmt, LoopIR.WindowStmt): self._name_lookup[stmt.name] = self._transpile_buffer_arg( - stmt.rhs, coverage_node + stmt_cursor._child_node("rhs"), stmt.name ) else: assert False, "unsupported stmt" + if self._coverage_state is not None: + self._coverage_state.exit_stmt(stmt_cursor) + def _transpile_buffer_arg( - self, expr: LoopIR.expr, coverage_node: Optional[CoverageSkeletonNode] + self, expr_cursor: Node, new_name: Sym ) -> Union[Tensor, Reference]: + expr = expr_cursor._node if isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0 buf = self._name_lookup[expr.name] assert isinstance(buf, (Tensor, Reference)) + if self._coverage_state is not None: + if isinstance(buf, Tensor): + self._coverage_state.assign_tensor(new_name, expr.name) + else: + self._coverage_state.assign_scalar(new_name, expr.name) return buf elif isinstance(expr, LoopIR.WindowExpr): base = self._name_lookup[expr.name] assert isinstance(base, Tensor) - offset_expr = base.offset window_dims = [] - for idx, dim in zip(expr.idx, base.dims): - if isinstance(idx, LoopIR.Interval): - lo_expr = self._transpile_expr(idx.lo, coverage_node) - hi_expr = self._transpile_expr(idx.hi, coverage_node) - self._assert_at_runtime( - f"(0<={lo_expr}&&{lo_expr}<={hi_expr}&&{hi_expr}<={dim.size})" - ) - offset_expr = f"({offset_expr}+Math.imul({lo_expr},{dim.stride}))" - size_expr = f"({hi_expr}-{lo_expr})" - window_dims.append(Dimension(size_expr, dim.stride)) - elif isinstance(idx, LoopIR.Point): - pt_expr = self._transpile_expr(idx.pt, coverage_node) - 
self._assert_at_runtime(f"(0<={pt_expr}&&{pt_expr}<{dim.size})") - offset_expr = f"({offset_expr}+Math.imul({pt_expr},{dim.stride}))" + in_bounds_conds = [] + idx_cursor_iter = iter(expr_cursor._child_block("idx")) + for dim in base.dims: + if isinstance(dim.window_idx, Point): + window_dims.append(dim) else: - assert False, "not a window index" - return Tensor( - base.name, offset_expr, tuple(window_dims), base.resize_placeholder - ) + idx_cursor = next(idx_cursor_iter) + idx = idx_cursor._node + if isinstance(idx, LoopIR.Interval): + lo_expr = self._transpile_expr( + idx_cursor._child_node("lo"), + ) + hi_expr = self._transpile_expr( + idx_cursor._child_node("hi"), + ) + lo_sym, hi_sym = Sym("lo"), Sym("hi") + self._js_lines.append( + f"let {repr(lo_sym)}=({lo_expr})+({dim.window_idx.lower_bound});" + ) + self._js_lines.append( + f"let {repr(hi_sym)}=({hi_expr})+({dim.window_idx.lower_bound});" + ) + in_bounds_conds.append( + f"(({dim.window_idx.lower_bound})<=({repr(lo_sym)})&&({repr(lo_sym)})<=({repr(hi_sym)})&&({repr(hi_sym)})<=({dim.window_idx.upper_bound}))" + ) + window_dims.append( + Dimension( + dim.size, dim.stride, Slice(repr(lo_sym), repr(hi_sym)) + ) + ) + elif isinstance(idx, LoopIR.Point): + pt_expr = self._transpile_expr( + idx_cursor._child_node("pt"), + ) + index_sym = Sym("idx") + self._js_lines.append( + f"let {repr(index_sym)}=({pt_expr})+({dim.window_idx.lower_bound});" + ) + in_bounds_conds.append( + f"(({dim.window_idx.lower_bound})<={repr(index_sym)}&&{repr(index_sym)}<({dim.window_idx.upper_bound}))" + ) + window_dims.append( + Dimension(dim.size, dim.stride, Point(repr(index_sym))) + ) + else: + assert False, "not a window index" + in_bounds_js = "&&".join(in_bounds_conds) + if self._coverage_state is not None: + self._coverage_state.assign_window(new_name, expr, in_bounds_js) + self._assert_at_runtime(in_bounds_js) + return Tensor(base.name, tuple(window_dims)) elif isinstance(expr, LoopIR.ReadConfig): self._configs.add((expr.config, expr.field)) + if self._coverage_state is not None: + self._coverage_state.assign_scalar_from_context( + new_name, expr.config, expr.field + ) return Reference( f'{CONTEXT_OBJECT_NAME}["{self.get_config_param_name(expr.config, expr.field)}"]', True, @@ -543,36 +1275,31 @@ def _transpile_buffer_arg( assert False, "unsupported buffer expression" def _transpile_expr( - self, expr: LoopIR.expr, coverage_node: Optional[CoverageSkeletonNode] + self, + expr_cursor: Node, ) -> str: + expr = expr_cursor._node if isinstance(expr, LoopIR.Read): buf = self._name_lookup[expr.name] - if self._coverage_state is not None and coverage_node is not None: - read_sym = Sym("read") - if expr.name not in self._coverage_state.buffer_reads: - self._coverage_state.buffer_reads[ - buf.name if isinstance(buf, Tensor) else expr.name - ] = [] - self._coverage_state.buffer_reads[ - buf.name if isinstance(buf, Tensor) else expr.name - ].append( - MemoryAccess( - read_sym, - coverage_node, - tuple( - self._coverage_state.cm.make_expression(idx) - for idx in expr.idx - ), - ( - self._make_tensor_access_fillers(read_sym, buf, expr.idx) - if isinstance(buf, Tensor) - else self._make_scalar_access_fillers(read_sym) - ), - ) - ) if isinstance(buf, Tensor): - return f"{repr(buf.name)}[{self._get_index_expr(buf, expr.idx, coverage_node)}]" + index_exprs = self._get_index_exprs( + buf, + expr_cursor._child_block("idx"), + ) + index = f"+".join( + f"Math.imul({dim.stride},{index_expr})" + for dim, index_expr in zip(buf.dims, index_exprs) + ) + in_bounds_js = 
self._get_in_bounds_condition(index_exprs, buf) + if self._coverage_state is not None: + self._coverage_state.access_tensor( + expr_cursor, index_exprs, False, in_bounds_js + ) + self._assert_at_runtime(in_bounds_js) + return f"{repr(buf.name)}[{index}]" elif isinstance(buf, Reference): + if self._coverage_state is not None: + self._coverage_state.access_scalar(expr_cursor, False) return buf.name if buf.is_config else f"{buf.name}[0]" else: return buf.name @@ -584,10 +1311,10 @@ def _transpile_expr( else: assert False, "unexpected const type" elif isinstance(expr, LoopIR.USub): - return f"(-{self._transpile_expr(expr.arg, coverage_node)})" + return f"(-{self._transpile_expr(expr_cursor._child_node('arg'))})" elif isinstance(expr, LoopIR.BinOp): - lhs = self._transpile_expr(expr.lhs, coverage_node) - rhs = self._transpile_expr(expr.rhs, coverage_node) + lhs = self._transpile_expr(expr_cursor._child_node("lhs")) + rhs = self._transpile_expr(expr_cursor._child_node("rhs")) is_int = ( isinstance(expr.type, (T.INT8, T.UINT8, T.UINT16, T.INT32)) or not expr.type.is_numeric() @@ -614,16 +1341,23 @@ def _transpile_expr( return val elif isinstance(expr, LoopIR.Extern): return expr.f.transpile( - tuple(self._transpile_expr(arg, coverage_node) for arg in expr.args) + tuple( + self._transpile_expr(arg_cursor) + for arg_cursor in expr_cursor._child_block("args") + ) ) elif isinstance(expr, LoopIR.WindowExpr): assert False, "unexpected window expr" elif isinstance(expr, LoopIR.StrideExpr): buf = self._name_lookup[expr.name] assert isinstance(buf, Tensor) - return buf.dims[expr.dim].stride + return tuple(dim for dim in buf.dims if isinstance(dim.window_idx, Slice))[ + expr.dim + ].stride elif isinstance(expr, LoopIR.ReadConfig): self._configs.add((expr.config, expr.field)) + if self._coverage_state is not None: + self._coverage_state.access_context(expr_cursor, False) return f'{CONTEXT_OBJECT_NAME}["{self.get_config_param_name(expr.config, expr.field)}"]' else: assert False, "unexpected expr" diff --git a/src/exo/backend/coverage.py b/src/exo/backend/coverage.py index 06071aacb..de93c13d1 100644 --- a/src/exo/backend/coverage.py +++ b/src/exo/backend/coverage.py @@ -6,12 +6,14 @@ from ..rewrite.constraint_solver import ( Constraint, ConstraintMaker, + ConstraintTerm, DisjointConstraint, TRUE_CONSTRAINT, Expression, Solution, ) from ..core.prelude import Sym +from ..core.internal_cursors import Node, NodePath @dataclass @@ -191,9 +193,9 @@ def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: def make_renamed_constraint_and_indices( self, state: CoverageSolverState - ) -> tuple[DisjointConstraint, tuple[Expression, ...]]: + ) -> tuple[DisjointConstraint, tuple[Expression, ...], dict[Sym, Sym]]: path_constraint = self.node.get_complete_constraint() - sym_renaming, _ = state.cm.rename_sym_set( + sym_renaming, var_renaming = state.cm.rename_sym_set( path_constraint.collect_syms().union( *(index_expr.collect_syms() for index_expr in self.index) ), @@ -202,6 +204,7 @@ def make_renamed_constraint_and_indices( return ( path_constraint.rename_syms(sym_renaming), tuple(index_expr.rename_syms(sym_renaming) for index_expr in self.index), + var_renaming, ) @@ -253,10 +256,12 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: ( access1_path_constraint, access1_indices, + _, ) = self.access1.make_renamed_constraint_and_indices(state) ( access2_path_constraint, access2_indices, + _, ) = self.access2.make_renamed_constraint_and_indices(state) path_constraints = 
access1_path_constraint.intersect( access2_path_constraint @@ -265,9 +270,7 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: for index1, index2 in zip(access1_indices, access2_indices): alias_constraint = alias_constraint.intersect( Constraint( - Expression( - tuple(term.negate() for term in index1.terms) + index2.terms - ), + index1.negate().add(index2), False, ).lift_to_disjoint_constraint() ) @@ -291,16 +294,364 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: return state +@dataclass +class FailureCondition: + coverage_sym: Sym + fail_cond: DisjointConstraint + node: CoverageSkeletonNode + indexed_fillers: tuple[IndexedFiller, ...] + visited_failure: bool = False # mutable + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.coverage_sym,)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + failed = coverage_result[repr(self.coverage_sym)] + assert isinstance(failed, bool) + self.visited_failure |= failed + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + 1 if self.visited_failure else 0, + 1, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + if not self.visited_failure: + path_constraint = self.node.get_complete_constraint().intersect( + self.fail_cond + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + self.visited_failure = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + +@dataclass +class SymbolicPoint: + index: Expression + + def collect_syms(self) -> frozenset[Sym]: + return self.index.collect_syms() + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "SymbolicPoint": + return SymbolicPoint(self.index.rename_syms(lookup)) + + +@dataclass +class SymbolicSlice: + lower_bound: Expression + upper_bound: Expression + + def collect_syms(self) -> frozenset[Sym]: + return self.lower_bound.collect_syms() | self.upper_bound.collect_syms() + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "SymbolicSlice": + return SymbolicSlice( + self.lower_bound.rename_syms(lookup), self.upper_bound.rename_syms(lookup) + ) + + +SymbolicWindowIndex = Union[SymbolicPoint, SymbolicSlice] + + +# for stage_mem +@dataclass +class StagingOverlap: + overlap_sym: Sym + disjoint_sym: Sym + staged_window: tuple[SymbolicWindowIndex, ...] + access_window: tuple[SymbolicWindowIndex, ...] + node: CoverageSkeletonNode + access_cursor: NodePath + indexed_fillers: tuple[IndexedFiller, ...] 
+ has_overlap: bool = False + has_disjoint_access: bool = False + visited_overlap: bool = False + visited_disjoint: bool = False + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.overlap_sym, self.disjoint_sym)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + overlap = coverage_result[repr(self.overlap_sym)] + disjoint = coverage_result[repr(self.disjoint_sym)] + assert isinstance(overlap, bool) and isinstance(disjoint, bool) + self.has_overlap |= overlap + self.visited_overlap |= overlap + self.has_disjoint_access |= disjoint + self.visited_disjoint |= disjoint + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + (1 if self.visited_overlap else 0) + (1 if self.visited_disjoint else 0), + 2, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + uncovered_overlap = None + if self.visited_overlap and not self.visited_disjoint: + uncovered_overlap = False + elif self.visited_disjoint and not self.visited_overlap: + uncovered_overlap = True + + if uncovered_overlap is not None: + path_constraint = self.node.get_complete_constraint() + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms().union( + *( + (access_window_idx.collect_syms()) + for access_window_idx in self.access_window + ), + *( + staged_window_idx.collect_syms() + for staged_window_idx in self.staged_window + ), + ), + state.free_vars, + ) + path_constraint = path_constraint.rename_syms(sym_renaming) + overlap_constraint = TRUE_CONSTRAINT + for access_idx, staged_idx in zip( + map(lambda idx: idx.rename_syms(sym_renaming), self.access_window), + map(lambda idx: idx.rename_syms(sym_renaming), self.staged_window), + ): + if isinstance(staged_idx, SymbolicPoint) and isinstance( + access_idx, SymbolicPoint + ): + index_overlap_constraint = Constraint( + access_idx.index.negate().add(staged_idx.index), + False, + ).lift_to_disjoint_constraint() + elif isinstance(staged_idx, SymbolicSlice) and isinstance( + access_idx, SymbolicPoint + ): + index_overlap_constraint = ( + Constraint( + staged_idx.lower_bound.negate().add(access_idx.index), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + access_idx.index.negate() + .add(staged_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + ) + ) + elif isinstance(staged_idx, SymbolicPoint) and isinstance( + access_idx, SymbolicSlice + ): + index_overlap_constraint = ( + Constraint( + access_idx.lower_bound.negate().add(staged_idx.index), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + staged_idx.index.negate() + .add(access_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + ) + ) + elif isinstance(staged_idx, SymbolicSlice) and isinstance( + access_idx, SymbolicSlice + ): + index_overlap_constraint = ( + Constraint( + staged_idx.lower_bound.negate() + .add(access_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + access_idx.lower_bound.negate() + .add(staged_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + ) + ) + else: + assert False + overlap_constraint = overlap_constraint.intersect( + index_overlap_constraint + ) + if not 
uncovered_overlap: + overlap_constraint = overlap_constraint.invert() + new_constraint = state.current_constraint.intersect( + path_constraint + ).intersect(overlap_constraint) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + if uncovered_overlap: + self.visited_overlap = True + else: + self.visited_disjoint = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + +@dataclass +class ParallelAccess: + node: CoverageSkeletonNode + index: tuple[Expression, ...] + indexed_fillers: tuple[IndexedFiller, ...] + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def make_renamed_constraint_and_indices( + self, state: CoverageSolverState + ) -> tuple[DisjointConstraint, tuple[Expression, ...], dict[Sym, Sym]]: + path_constraint = self.node.get_complete_constraint() + sym_renaming, var_renaming = state.cm.rename_sym_set( + path_constraint.collect_syms().union( + *(index_expr.collect_syms() for index_expr in self.index) + ), + state.free_vars, + ) + return ( + path_constraint.rename_syms(sym_renaming), + tuple(index_expr.rename_syms(sym_renaming) for index_expr in self.index), + var_renaming, + ) + + +@dataclass +class ParallelAccessPair: + coverage_sym: Sym + iter_sym: Sym + access1: ParallelAccess + access2: ParallelAccess + indexed_fillers: tuple[IndexedFiller, ...] + has_aliased: bool = False + visited_aliasing: bool = False # mutable + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + yield from self.access1.get_indexed_fillers() + yield from self.access2.get_indexed_fillers() + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.coverage_sym,)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + aliased = coverage_result[repr(self.coverage_sym)] + assert isinstance(aliased, bool) + self.has_aliased |= aliased + self.visited_aliasing |= aliased + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + 1 if self.visited_aliasing else 0, + 1, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + if not self.visited_aliasing: + ( + access1_path_constraint, + access1_indices, + access1_var_renaming, + ) = self.access1.make_renamed_constraint_and_indices(state) + ( + access2_path_constraint, + access2_indices, + access2_var_renaming, + ) = self.access2.make_renamed_constraint_and_indices(state) + path_constraints = access1_path_constraint.intersect( + access2_path_constraint + ) + alias_constraint = TRUE_CONSTRAINT + for index1, index2 in zip(access1_indices, access2_indices): + alias_constraint = alias_constraint.intersect( + Constraint( + index1.negate().add(index2), + False, + ).lift_to_disjoint_constraint() + ) + + different_iteration_constraint = TRUE_CONSTRAINT + if ( + self.iter_sym in access1_var_renaming + and self.iter_sym in access2_var_renaming + ): + different_iteration_constraint = Constraint( + Expression.from_sym(access1_var_renaming[self.iter_sym]) + .negate() + .add(Expression.from_sym(access2_var_renaming[self.iter_sym])), + False, + ).invert() + + new_constraint = ( + state.current_constraint.intersect(path_constraints) + .intersect(alias_constraint) + 
.intersect(different_iteration_constraint) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + self.visited_aliasing = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + @dataclass class CoverageSkeleton: roots: tuple[CoverageSkeletonNode, ...] aliasable_accesses: tuple[MemoryAccessPair, ...] + failure_conditions: tuple[FailureCondition, ...] + staging_overlaps: tuple[StagingOverlap, ...] + parallel_accesses: tuple[ParallelAccessPair, ...] free_vars: frozenset[Sym] def merge(self, other: "CoverageSkeleton") -> "CoverageSkeleton": return CoverageSkeleton( self.roots + other.roots, self.aliasable_accesses + other.aliasable_accesses, + self.failure_conditions + other.failure_conditions, + self.staging_overlaps + other.staging_overlaps, + self.parallel_accesses + other.parallel_accesses, self.free_vars | other.free_vars, ) @@ -309,6 +660,12 @@ def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: yield from root.get_indexed_fillers() for aliasable_access in self.aliasable_accesses: yield from aliasable_access.get_indexed_fillers() + for failure_condition in self.failure_conditions: + yield from failure_condition.get_indexed_fillers() + for staging_overlap in self.staging_overlaps: + yield from staging_overlap.get_indexed_fillers() + for parallel_access in self.parallel_accesses: + yield from parallel_access.get_indexed_fillers() def get_coverage_syms(self) -> frozenset[Sym]: return frozenset().union( @@ -317,6 +674,18 @@ def get_coverage_syms(self) -> frozenset[Sym]: aliasable_access.get_coverage_syms() for aliasable_access in self.aliasable_accesses ), + *tuple( + failure_condition.get_coverage_syms() + for failure_condition in self.failure_conditions + ), + *tuple( + staging_overlap.get_coverage_syms() + for staging_overlap in self.staging_overlaps + ), + *tuple( + parallel_access.get_coverage_syms() + for parallel_access in self.parallel_accesses + ), ) def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): @@ -324,6 +693,12 @@ def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): root_node.update_coverage(coverage_result) for aliasable_access in self.aliasable_accesses: aliasable_access.update_coverage(coverage_result) + for failure_condition in self.failure_conditions: + failure_condition.update_coverage(coverage_result) + for staging_overlap in self.staging_overlaps: + staging_overlap.update_coverage(coverage_result) + for parallel_access in self.parallel_accesses: + parallel_access.update_coverage(coverage_result) def get_coverage_progress(self) -> CoverageProgress: result = CoverageProgress(0, 0) @@ -331,6 +706,12 @@ def get_coverage_progress(self) -> CoverageProgress: result = root_node.get_coverage_progress() for aliasable_access in self.aliasable_accesses: result = result.merge(aliasable_access.get_coverage_progress()) + for failure_condition in self.failure_conditions: + result = result.merge(failure_condition.get_coverage_progress()) + for staging_overlap in self.staging_overlaps: + result = result.merge(staging_overlap.get_coverage_progress()) + for parallel_access in self.parallel_accesses: + result = result.merge(parallel_access.get_coverage_progress()) return result def solve_constraint_with_coverage( @@ -355,6 +736,12 @@ def solve_constraint_with_coverage( bound, search_limit, ) + for 
parallel_access in self.parallel_accesses: + state = parallel_access.solve_coverage(state) + for staging_overlap in self.staging_overlaps: + state = staging_overlap.solve_coverage(state) + for failure_condition in self.failure_conditions: + state = failure_condition.solve_coverage(state) for aliasable_access in self.aliasable_accesses: state = aliasable_access.solve_coverage(state) for root_node in self.roots: diff --git a/src/exo/core/internal_cursors.py b/src/exo/core/internal_cursors.py index 1d4c98a8f..3abae4e29 100644 --- a/src/exo/core/internal_cursors.py +++ b/src/exo/core/internal_cursors.py @@ -581,6 +581,11 @@ def forward(cursor: Node): return forward +@dataclass(frozen=True) +class NodePath: + path: tuple[tuple[str, Optional[int]]] + + @dataclass class Node(Cursor): _path: list[tuple[str, Optional[int]]] @@ -605,6 +610,13 @@ def _node(self): return n + # ------------------------------------------------------------------------ # + # Hashable path accessor + # ------------------------------------------------------------------------ # + + def get_path(self) -> NodePath: + return NodePath(tuple(self._path)) + # ------------------------------------------------------------------------ # # Navigation (implementation) # ------------------------------------------------------------------------ # diff --git a/src/exo/frontend/typecheck.py b/src/exo/frontend/typecheck.py index 9c89cb7e9..883c5b25b 100644 --- a/src/exo/frontend/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -97,6 +97,7 @@ def __init__(self, proc): self.uast_proc = proc self.env = dict() self.errors = [] + self.must_fuzz_reason = None args = [] for a in proc.args: diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index 86dad4381..877f8da90 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -410,7 +410,7 @@ def DoReorderStmt(f_cursor, s_cursor): do_check( lambda: Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node), lambda: fuzz_reorder_stmts(f_cursor, s_cursor), - "dynamic", + "both", ) ir, fwd = s_cursor._move(f_cursor.before()) return ir, fwd diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index c630f9165..d191c3b3f 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -402,7 +402,7 @@ def run_test_case( eval_info = js_eval(javascript)(*buffer_args) except Exception as e: raise Exception( - f"javascript:\n{javascript}\nproc:\n{transpiled_proc.proc}" + f"javascript:\n{javascript}\nproc:\n{transpiled_proc.get_proc()}" ) from e if transpiled_proc.get_coverage_skeleton() is None: [result, ctxt_object] = eval_info @@ -504,8 +504,13 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): spec1 = cur_scope.get_test_spec(cm) spec2 = transformed.get_test_spec(cm) - transpiled_test1 = Transpiler(spec1.proc, CoverageArgs(cm)) - transpiled_test2 = Transpiler(spec2.proc, CoverageArgs(cm)) + failure_scope = ( + starting_scope.as_block() + if isinstance(starting_scope, Node) + else starting_scope + ) + transpiled_test1 = Transpiler(spec1.proc, CoverageArgs(cm, failure_scope)) + transpiled_test2 = Transpiler(spec2.proc, CoverageArgs(cm, failure_scope)) config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 086997ed8..164a33e6e 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -66,6 +66,27 @@ class 
LinearConstraint: class Expression: terms: tuple[ConstraintTerm, ...] + @staticmethod + def from_constant(const: int) -> "Expression": + return Expression((ConstraintTerm(const, ()),)) + + @staticmethod + def from_sym(sym: Sym) -> "Expression": + return Expression((ConstraintTerm(1, (sym,)),)) + + def negate(self) -> "Expression": + return Expression(tuple(term.negate() for term in self.terms)) + + def add(self, other: "Expression") -> "Expression": + return Expression((*self.terms, *other.terms)) + + def multiply(self, other: "Expression") -> "Expression": + return Expression( + tuple( + term1.multiply(term2) for term1 in self.terms for term2 in other.terms + ) + ) + def substitute(self, assignments: dict[Sym, int]) -> "Expression": coefficients: dict[tuple[Sym, ...], int] = {} for term in self.terms: @@ -132,10 +153,7 @@ def lift_to_disjoint_constraint(self) -> "DisjointConstraint": def invert(self) -> "DisjointConstraint": if self.has_slack: return Constraint( - Expression( - tuple(term.negate() for term in self.lhs.terms) - + (ConstraintTerm(-1, ()),) - ), + self.lhs.negate().add(Expression.from_constant(-1)), True, ).lift_to_disjoint_constraint() else: @@ -144,7 +162,7 @@ def invert(self) -> "DisjointConstraint": ConstraintClause( ( Constraint( - Expression(self.lhs.terms + (ConstraintTerm(-1, ()),)), + self.lhs.add(Expression.from_constant(-1)), True, ), ) @@ -152,10 +170,7 @@ def invert(self) -> "DisjointConstraint": ConstraintClause( ( Constraint( - Expression( - tuple(term.negate() for term in self.lhs.terms) - + (ConstraintTerm(-1, ()),) - ), + self.lhs.negate().add(Expression.from_constant(-1)), True, ), ) @@ -323,31 +338,29 @@ def __init__(self, type_map: dict[Sym, LoopIR.type]): if var_sub_result is not None: self.var_subs[sym] = var_sub_result + def get_var_sub(self, var_sym: Sym) -> Expression: + return self.var_subs[var_sym] + def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression]: if isinstance(var_type, (T.Size, T.Stride)): # positive variable - return Expression( - (ConstraintTerm(1, (Sym(f"{name}_m1"),)), ConstraintTerm(1, ())) + return Expression.from_sym(Sym(f"{name}_m1")).add( + Expression.from_constant(1) ) elif isinstance(var_type, (T.Int, T.Index)): # unsigned variables are represented as a - b, where a and b are nonnegative a, b = Sym(f"{name}_a"), Sym(f"{name}_b") - return Expression((ConstraintTerm(1, (a,)), ConstraintTerm(-1, (b,)))) + return Expression.from_sym(a).add(Expression.from_sym(b).negate()) elif isinstance(var_type, T.Bool): # constrained to [0, 1] sym = Sym(name) self.extra_constraints.append( Constraint( - Expression( - ( - ConstraintTerm(-1, (sym,)), - ConstraintTerm(1, ()), - ) - ), + Expression.from_sym(sym).negate().add(Expression.from_constant(1)), True, ) ) - return Expression((ConstraintTerm(1, (sym,)),)) + return Expression.from_sym(sym) else: return None @@ -361,60 +374,40 @@ def make_expression(self, expr: Union[LoopIR.expr, Sym]) -> Expression: ), "indexing not supported in assertions (yet, todo)" return self.var_subs[expr.name] elif isinstance(expr, LoopIR.Const): - return Expression((ConstraintTerm(expr.val, ()),)) + return Expression.from_constant(expr.val) elif isinstance(expr, LoopIR.USub): - return Expression( - tuple(term.negate() for term in self.make_expression(expr.arg).terms) - ) + return self.make_expression(expr.arg).negate() elif isinstance(expr, LoopIR.BinOp): # TODO: support mod and div using extra variables - lhs_terms = self.make_expression(expr.lhs).terms - rhs_terms = 
self.make_expression(expr.rhs).terms + lhs = self.make_expression(expr.lhs) + rhs = self.make_expression(expr.rhs) if expr.op == "+": - return Expression(lhs_terms + rhs_terms) + return lhs.add(rhs) elif expr.op == "-": - return Expression( - lhs_terms + tuple(term.negate() for term in rhs_terms) - ) + return lhs.add(rhs.negate()) elif expr.op == "*": - return Expression( - tuple( - lhs_term.multiply(rhs_term) - for lhs_term in lhs_terms - for rhs_term in rhs_terms - ) - ) + return lhs.multiply(rhs) elif expr.op in ["/", "%"]: div, rem = Sym("div"), Sym("rem") self.hidden_vars.update((div, rem)) self.extra_constraints.append( Constraint( - Expression( - tuple(lhs_term.negate() for lhs_term in lhs_terms) - + (ConstraintTerm(1, (rem,)),) - + tuple( - rhs_term.multiply(ConstraintTerm(1, (div,))) - for rhs_term in rhs_terms - ) - ), + lhs.negate() + .add(Expression.from_sym(rem)) + .add(rhs.multiply(Expression.from_sym(div))), False, ) ) self.extra_constraints.append( Constraint( - Expression( - ( - ConstraintTerm(-1, (rem,)), - ConstraintTerm(-1, ()), - ) - + rhs_terms - ), + Expression.from_sym(rem) + .add(Expression.from_constant(1)) + .negate() + .add(rhs), True, ) ) - return Expression( - (ConstraintTerm(1, (rem if expr.op == "%" else div,)),) - ) + return Expression.from_sym(rem if expr.op == "%" else div) else: assert False, f"unsupported op in assertion: {expr.op}" elif isinstance(expr, LoopIR.StrideExpr): @@ -422,7 +415,7 @@ def make_expression(self, expr: Union[LoopIR.expr, Sym]) -> Expression: new_sym = Sym("stride") self.stride_dummies[(expr.name, expr.dim)] = new_sym dummy = self.stride_dummies[(expr.name, expr.dim)] - return Expression((ConstraintTerm(1, (dummy,)),)) + return Expression.from_sym(dummy) elif isinstance(expr, LoopIR.ReadConfig): if (expr.config, expr.field) not in self.ctxt: field_type = expr.config.lookup_type(expr.field) @@ -460,12 +453,7 @@ def make_constraint( elif isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0, "cannot index into boolean" return Constraint( - Expression( - ( - ConstraintTerm(1, (expr.name,)), - ConstraintTerm(-1, ()), - ) - ), + Expression.from_sym(expr.name).add(Expression.from_constant(-1)), True, ).lift_to_disjoint_constraint() elif isinstance(expr, LoopIR.Const): @@ -476,31 +464,24 @@ def make_constraint( def make_constraint_from_inequality( self, lhs: Union[LoopIR.expr, Sym], rhs: Union[LoopIR.expr, Sym], op: str ) -> Constraint: - lhs_terms = self.make_expression(lhs).terms - rhs_terms = self.make_expression(rhs).terms - has_slack = True + lhs_expr = self.make_expression(lhs) + rhs_expr = self.make_expression(rhs) if op == "<": - terms = ( - rhs_terms - + tuple(term.negate() for term in lhs_terms) - + (ConstraintTerm(-1, ()),) + return Constraint( + rhs_expr.add(lhs_expr.negate()).add(Expression.from_constant(-1)), True ) elif op == ">": - terms = ( - lhs_terms - + tuple(term.negate() for term in rhs_terms) - + (ConstraintTerm(-1, ()),) + return Constraint( + lhs_expr.add(rhs_expr.negate()).add(Expression.from_constant(-1)), True ) elif op == "<=": - terms = rhs_terms + tuple(term.negate() for term in lhs_terms) + return Constraint(rhs_expr.add(lhs_expr.negate()), True) elif op == ">=": - terms = lhs_terms + tuple(term.negate() for term in rhs_terms) + return Constraint(lhs_expr.add(rhs_expr.negate()), True) elif op == "==": - has_slack = False - terms = rhs_terms + tuple(term.negate() for term in lhs_terms) + return Constraint(lhs_expr.add(rhs_expr.negate()), False) else: assert False, "boolean ops expected" - return 
Constraint(Expression(terms), has_slack) def _make_solution_from_assignments(self, assignments: dict[Sym, int]) -> Solution: var_assignments = {} diff --git a/tests/golden/test_transpiler/test_matmul.txt b/tests/golden/test_transpiler/test_matmul.txt index 1f15cca5a..ea657ab2c 100644 --- a/tests/golden/test_transpiler/test_matmul.txt +++ b/tests/golden/test_transpiler/test_matmul.txt @@ -1,21 +1,15 @@ ((a_4,b_5,c_6)=>{ ctxt={} -if(!($size_a_4_0==$N_1))return [1,ctxt,{}]; -if(!($size_a_4_1==$K_3))return [1,ctxt,{}]; -if(!($size_b_5_0==$K_3))return [1,ctxt,{}]; -if(!($size_b_5_1==$M_2))return [1,ctxt,{}]; -if(!($size_c_6_0==$N_1))return [1,ctxt,{}]; -if(!($size_c_6_1==$M_2))return [1,ctxt,{}]; +if(!((($size_a_4_0)-(0))==($N_1))&&((($size_a_4_1)-(0))==($K_3)))return [1,ctxt,{}]; +if(!((($size_b_5_0)-(0))==($K_3))&&((($size_b_5_1)-(0))==($M_2)))return [1,ctxt,{}]; +if(!((($size_c_6_0)-(0))==($N_1))&&((($size_c_6_1)-(0))==($M_2)))return [1,ctxt,{}]; for(let i_7=0;i_7<$N_1;i_7++){ for(let j_8=0;j_8<$M_2;j_8++){ for(let k_9=0;k_9<$K_3;k_9++){ -if(!(i_7<$size_a_4_0&&i_7>=0))return [1,ctxt,{}]; -if(!(k_9<$size_a_4_1&&k_9>=0))return [1,ctxt,{}]; -if(!(k_9<$size_b_5_0&&k_9>=0))return [1,ctxt,{}]; -if(!(j_8<$size_b_5_1&&j_8>=0))return [1,ctxt,{}]; -if(!(i_7<$size_c_6_0&&i_7>=0))return [1,ctxt,{}]; -if(!(j_8<$size_c_6_1&&j_8>=0))return [1,ctxt,{}]; -c_6[Math.imul(i_7,$stride_c_6_0)+Math.imul(j_8,$stride_c_6_1)+0]+=(a_4[Math.imul(i_7,$stride_a_4_0)+Math.imul(k_9,$stride_a_4_1)+0]*b_5[Math.imul(k_9,$stride_b_5_0)+Math.imul(j_8,$stride_b_5_1)+0]); +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_a_4_0))&&((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_a_4_1)))return [1,ctxt,{}]; +if(!((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_b_5_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_b_5_1)))return [1,ctxt,{}]; +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_c_6_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_c_6_1)))return [1,ctxt,{}]; +c_6[Math.imul($stride_c_6_0,((i_7)+(0)))+Math.imul($stride_c_6_1,((j_8)+(0)))]+=(a_4[Math.imul($stride_a_4_0,((i_7)+(0)))+Math.imul($stride_a_4_1,((k_9)+(0)))]*b_5[Math.imul($stride_b_5_0,((k_9)+(0)))+Math.imul($stride_b_5_1,((j_8)+(0)))]); } } } diff --git a/tests/golden/test_transpiler/test_matmul_coverage.txt b/tests/golden/test_transpiler/test_matmul_coverage.txt index d7f3b98d9..6d8dd17df 100644 --- a/tests/golden/test_transpiler/test_matmul_coverage.txt +++ b/tests/golden/test_transpiler/test_matmul_coverage.txt @@ -1,44 +1,26 @@ ((a_4,b_5,c_6)=>{ ctxt={} -let body_23=false;let body_26=false;let body_29=false;let skip_24=false;let skip_27=false;let skip_30=false; -if(!($size_a_4_0==$N_1))return [1,ctxt,{}]; -if(!($size_a_4_1==$K_3))return [1,ctxt,{}]; -if(!($size_b_5_0==$K_3))return [1,ctxt,{}]; -if(!($size_b_5_1==$M_2))return [1,ctxt,{}]; -if(!($size_c_6_0==$N_1))return [1,ctxt,{}]; -if(!($size_c_6_1==$M_2))return [1,ctxt,{}]; +let body_23=false;let body_25=false;let body_27=false;let skip_24=false;let skip_26=false;let skip_28=false; + + +if(!((($size_a_4_0)-(0))==($N_1))&&((($size_a_4_1)-(0))==($K_3)))return [1,ctxt,{}]; +if(!((($size_b_5_0)-(0))==($K_3))&&((($size_b_5_1)-(0))==($M_2)))return [1,ctxt,{}]; +if(!((($size_c_6_0)-(0))==($N_1))&&((($size_c_6_1)-(0))==($M_2)))return [1,ctxt,{}]; body_23||=(0<$N_1);skip_24||=(0>=$N_1); for(let i_7=0;i_7<$N_1;i_7++){ - -body_26||=(0<$M_2);skip_27||=(0>=$M_2); +body_25||=(0<$M_2);skip_26||=(0>=$M_2); for(let j_8=0;j_8<$M_2;j_8++){ - -body_29||=(0<$K_3);skip_30||=(0>=$K_3); +body_27||=(0<$K_3);skip_28||=(0>=$K_3); for(let 
k_9=0;k_9<$K_3;k_9++){ -if(!(i_7<$size_c_6_0&&i_7>=0))return [1,ctxt,{}]; -if(!(j_8<$size_c_6_1&&j_8>=0))return [1,ctxt,{}]; - -if(!(i_7<$size_a_4_0&&i_7>=0))return [1,ctxt,{}]; -if(!(k_9<$size_a_4_1&&k_9>=0))return [1,ctxt,{}]; - - - -if(!(i_7<$size_a_4_0&&i_7>=0))return [1,ctxt,{}]; -if(!(k_9<$size_a_4_1&&k_9>=0))return [1,ctxt,{}]; -if(!(k_9<$size_b_5_0&&k_9>=0))return [1,ctxt,{}]; -if(!(j_8<$size_b_5_1&&j_8>=0))return [1,ctxt,{}]; - - -if(!(k_9<$size_b_5_0&&k_9>=0))return [1,ctxt,{}]; -if(!(j_8<$size_b_5_1&&j_8>=0))return [1,ctxt,{}]; +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_a_4_0))&&((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_a_4_1)))return [1,ctxt,{}]; +if(!((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_b_5_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_b_5_1)))return [1,ctxt,{}]; -if(!(i_7<$size_c_6_0&&i_7>=0))return [1,ctxt,{}]; -if(!(j_8<$size_c_6_1&&j_8>=0))return [1,ctxt,{}]; -c_6[Math.imul(i_7,$stride_c_6_0)+Math.imul(j_8,$stride_c_6_1)+0]+=(a_4[Math.imul(i_7,$stride_a_4_0)+Math.imul(k_9,$stride_a_4_1)+0]*b_5[Math.imul(k_9,$stride_b_5_0)+Math.imul(j_8,$stride_b_5_1)+0]); +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_c_6_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_c_6_1)))return [1,ctxt,{}]; +c_6[Math.imul($stride_c_6_0,((i_7)+(0)))+Math.imul($stride_c_6_1,((j_8)+(0)))]+=(a_4[Math.imul($stride_a_4_0,((i_7)+(0)))+Math.imul($stride_a_4_1,((k_9)+(0)))]*b_5[Math.imul($stride_b_5_0,((k_9)+(0)))+Math.imul($stride_b_5_1,((j_8)+(0)))]); } } } -return [0,ctxt,{body_23,body_26,body_29,skip_24,skip_27,skip_30}];}) \ No newline at end of file +return [0,ctxt,{body_23,body_25,body_27,skip_24,skip_26,skip_28}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt b/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt index 6040168ee..83489e76e 100644 --- a/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt +++ b/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt @@ -1,30 +1,25 @@ ((b_2)=>{ ctxt={} -let body_12=false;let false_case_17=false;let false_case_25=false;let skip_13=false;let true_case_16=false;let true_case_24=false;let write_20=false;let write_21=false;let write_26=false;let write_27=false; - +let access_18=false;let access_19=false;let access_22=false;let access_23=false;let body_12=false;let false_case_15=false;let false_case_21=false;let skip_13=false;let true_case_14=false;let true_case_20=false; body_12||=(0<$n_1);skip_13||=(0>=$n_1); for(let i_3=0;i_3<$n_1;i_3++){ - - if((i_3<(($n_1/2)|0))){ -true_case_16=true; -write_20=true; +true_case_14=true; +access_18=true; b_2[0]=2; }else{ -false_case_17=true; -write_21=true; +false_case_15=true; +access_19=true; b_2[0]=3; } - - if((i_3==($n_1-1))){ -true_case_24=true; -write_26=true; +true_case_20=true; +access_22=true; b_2[0]+=1; }else{ -false_case_25=true; -write_27=true; +false_case_21=true; +access_23=true; b_2[0]+=2; } } -return [0,ctxt,{body_12,false_case_17,false_case_25,skip_13,true_case_16,true_case_24,write_20,write_21,write_26,write_27}];}) \ No newline at end of file +return [0,ctxt,{access_18,access_19,access_22,access_23,body_12,false_case_15,false_case_21,skip_13,true_case_14,true_case_20}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_variable_length_array_coverage.txt b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt index e8605468b..54dba4997 100644 --- a/tests/golden/test_transpiler/test_variable_length_array_coverage.txt +++ 
b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt @@ -1,24 +1,19 @@ (()=>{ ctxt={} -let body_9=false;let skip_10=false;let write_12=new ArrayBuffer(1,{maxByteLength:16});let write_15=new ArrayBuffer(1,{maxByteLength:16}); +let access_11=new ArrayBuffer(1,{maxByteLength:16});let access_13=new ArrayBuffer(1,{maxByteLength:16});let body_9=false;let skip_10=false; if(!($n_1>2))return [1,ctxt,{}]; - body_9||=(2<$n_1);skip_10||=(2>=$n_1); for(let i_2=2;i_2<$n_1;i_2++){ - -if(!(i_2>=0))return [1,ctxt,{}]; -let b_3=new Int32Array(i_2); -while(i_2>write_12.maxByteLength){let temp_13=new ArrayBuffer(write_12.byteLength,{maxByteLength:2*write_12.maxByteLength});for(let i=0;iwrite_15.maxByteLength){let temp_16=new ArrayBuffer(write_15.byteLength,{maxByteLength:2*write_15.maxByteLength});for(let i=0;i=0))return [1,ctxt,{}]; -if(!((i_2-1)=0))return [1,ctxt,{}]; -write_12[Math.imul((i_2-1),1)+0]=1; - -if(!((i_2-1)=0))return [1,ctxt,{}]; -b_3[Math.imul((i_2-1),1)+0]=0; -if(!((i_2-2)=0))return [1,ctxt,{}]; -write_15[Math.imul((i_2-2),1)+0]=1; - -if(!((i_2-2)=0))return [1,ctxt,{}]; -b_3[Math.imul((i_2-2),1)+0]=1; +let size_b_3_0=i_2; +let stride_b_3_0=1; +while(Math.imul(size_b_3_0,stride_b_3_0)>access_11.maxByteLength){let temp_12=new ArrayBuffer(access_11.byteLength,{maxByteLength:2*access_11.maxByteLength});for(let i=0;iaccess_13.maxByteLength){let temp_14=new ArrayBuffer(access_13.byteLength,{maxByteLength:2*access_13.maxByteLength});for(let i=0;i=0))return [1,ctxt,{}]; +let b_3=new Int32Array(Math.imul(size_b_3_0,stride_b_3_0)); +access_11[Math.imul((((i_2-1))+(0)),stride_b_3_0)]=1; +if(!(((((i_2-1))+(0)))>=(0)&&((((i_2-1))+(0)))<(size_b_3_0)))return [1,ctxt,{}]; +b_3[Math.imul(stride_b_3_0,(((i_2-1))+(0)))]=0; +access_13[Math.imul((((i_2-2))+(0)),stride_b_3_0)]=1; +if(!(((((i_2-2))+(0)))>=(0)&&((((i_2-2))+(0)))<(size_b_3_0)))return [1,ctxt,{}]; +b_3[Math.imul(stride_b_3_0,(((i_2-2))+(0)))]=1; } -return [0,ctxt,{body_9,skip_10,write_12,write_15}];}) \ No newline at end of file +return [0,ctxt,{access_11,access_13,body_9,skip_10}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_window_coverage.txt b/tests/golden/test_transpiler/test_window_coverage.txt index 7873ddb58..ba116af67 100644 --- a/tests/golden/test_transpiler/test_window_coverage.txt +++ b/tests/golden/test_transpiler/test_window_coverage.txt @@ -1,14 +1,15 @@ ((a_1)=>{ ctxt={} - -if(!($size_a_1_0==16))return [1,ctxt,{}]; -if(!(0<=1&&1<=8&&8<=$size_a_1_0))return [1,ctxt,{}]; -if(!(3<$size_a_1_0&&3>=0))return [1,ctxt,{}]; - -if(!(3<$size_a_1_0&&3>=0))return [1,ctxt,{}]; -a_1[Math.imul(3,$stride_a_1_0)+0]=2; -if(!(2<(8-1)&&2>=0))return [1,ctxt,{}]; - -if(!(2<(8-1)&&2>=0))return [1,ctxt,{}]; -a_1[Math.imul(2,$stride_a_1_0)+(0+Math.imul(1,$stride_a_1_0))]=3; -return [0,ctxt,{}];}) \ No newline at end of file +let access_10=new ArrayBuffer(1,{maxByteLength:16});let access_8=new ArrayBuffer(1,{maxByteLength:16}); +while(Math.imul($size_a_1_0,$stride_a_1_0)>access_10.maxByteLength){let temp_11=new ArrayBuffer(access_10.byteLength,{maxByteLength:2*access_10.maxByteLength});for(let i=0;iaccess_8.maxByteLength){let temp_9=new ArrayBuffer(access_8.byteLength,{maxByteLength:2*access_8.maxByteLength});for(let i=0;i=(0)&&(((3)+(0)))<($size_a_1_0)))return [1,ctxt,{}]; +a_1[Math.imul($stride_a_1_0,((3)+(0)))]=2; +access_10[Math.imul(((2)+(lo_6)),$stride_a_1_0)]=1; +if(!((((2)+(lo_6)))>=(lo_6)&&(((2)+(lo_6)))<(hi_7)))return [1,ctxt,{}]; +a_1[Math.imul($stride_a_1_0,((2)+(lo_6)))]=3; +return 
[0,ctxt,{access_10,access_8}];}) \ No newline at end of file From 1dca899768c1eba2c57413be181191b96f72f891 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 13 May 2025 15:53:28 -0400 Subject: [PATCH 13/24] fix tests --- .../test_chexo/test_path_constraints.txt | 2 +- .../test_constraint_solver/test_divmod.txt | 4 ++-- .../test_divmod_solve.txt | 4 ++-- .../test_constraint_solver/test_inversion.txt | 24 +++++++++---------- .../test_constraint_solver/test_solve.txt | 4 ++-- tests/test_chexo.py | 2 +- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/golden/test_chexo/test_path_constraints.txt b/tests/golden/test_chexo/test_path_constraints.txt index 5d6eee9b3..a5f8740bc 100644 --- a/tests/golden/test_chexo/test_path_constraints.txt +++ b/tests/golden/test_chexo/test_path_constraints.txt @@ -2,7 +2,7 @@ union( intersect( 1 * j_a + -1 * j_b + 0 >= 0, 1 * a_m1 + 1 + -1 * j_a + 1 * j_b + -1 >= 0, - -1 * a_m1 + -1 + 2 * i_a + -2 * i_b + 1 + -1 >= 0, + 1 * a_m1 + 1 + -2 * i_a + 2 * i_b + -1 >= 0, 1 * i_a + -1 * i_b + 0 >= 0, 1 * a_m1 + 1 + -1 * i_a + 1 * i_b + -1 >= 0, ), diff --git a/tests/golden/test_constraint_solver/test_divmod.txt b/tests/golden/test_constraint_solver/test_divmod.txt index 336c90602..bd2c9c003 100644 --- a/tests/golden/test_constraint_solver/test_divmod.txt +++ b/tests/golden/test_constraint_solver/test_divmod.txt @@ -2,11 +2,11 @@ union( intersect( 4 * a_m1 + 4 + 1 * b_m1 + 1 + -1 * c_m1 + -1 + -1 >= 0, 5 + -1 * b_m1 + -1 + -1 >= 0, - 3 + -1 * rem == 0, + 1 * rem + -3 == 0, ), intersect( 3 + -1 * a_m1 + -1 >= 0, 5 + -1 * b_m1 + -1 + -1 >= 0, - 3 + -1 * rem == 0, + 1 * rem + -3 == 0, ), ) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod_solve.txt b/tests/golden/test_constraint_solver/test_divmod_solve.txt index f94895ebf..abd52673b 100644 --- a/tests/golden/test_constraint_solver/test_divmod_solve.txt +++ b/tests/golden/test_constraint_solver/test_divmod_solve.txt @@ -1,3 +1,3 @@ a = 15 -b = 4 -c = 15 \ No newline at end of file +b = 2 +c = 25 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_inversion.txt b/tests/golden/test_constraint_solver/test_inversion.txt index 2874affa4..3a38b12fd 100644 --- a/tests/golden/test_constraint_solver/test_inversion.txt +++ b/tests/golden/test_constraint_solver/test_inversion.txt @@ -5,34 +5,34 @@ union( ), intersect( -3 + 1 * a_m1 + 1 + -1 >= 0, - 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, ), intersect( -3 + 1 * a_m1 + 1 + -1 >= 0, - -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, ), intersect( - 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, -4 + 1 * b_m1 + 1 + -1 >= 0, ), intersect( - 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, - 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, ), intersect( - 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, - -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, ), intersect( - -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, -4 + 1 * b_m1 + 1 + -1 >= 0, ), intersect( - -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, - 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 
0, ), intersect( - -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, - -4 + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, ), ) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt b/tests/golden/test_constraint_solver/test_solve.txt index fabab028a..f91dff186 100644 --- a/tests/golden/test_constraint_solver/test_solve.txt +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -1,3 +1,3 @@ a = 22 -b = 2 -c = 16 \ No newline at end of file +b = 3 +c = 72 \ No newline at end of file diff --git a/tests/test_chexo.py b/tests/test_chexo.py index 6e0b0cc09..22c0ac3b2 100644 --- a/tests/test_chexo.py +++ b/tests/test_chexo.py @@ -49,7 +49,7 @@ def foo(a: size, b: f32[a]): free_vars = get_free_variables( type_visitor.type_map, type_visitor.mem_map, - [cursor._impl._node for cursor in foo.find("c: _").as_block().expand()], + foo.find("c: _")._impl.as_block().expand(), ) assert golden == stringify_dict(free_vars) From 3b83e7935e2f2dab923df184f187bfcb6bb38a31 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 13 May 2025 18:53:24 -0400 Subject: [PATCH 14/24] fix window scoping --- src/exo/backend/LoopIR_transpiler.py | 139 +++++++++++++++------------ src/exo/rewrite/chexo.py | 30 ++---- tests/test_constraint_solver.py | 14 +-- 3 files changed, 96 insertions(+), 87 deletions(-) diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index beda60a8c..45a1a1daf 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -698,17 +698,17 @@ def assign_scalar_from_context(self, scalar_sym: Sym, config: Config, field: str self.scalar_symbols[scalar_sym] = self.ctxt_symbols[config_key] def assign_window( - self, sym: Sym, window_expr: LoopIR.WindowExpr, in_bounds_js: str + self, sym: Sym, source_buf: Sym, access_cursor: Block, in_bounds_js: str ): - base_tensor = self.symbolic_tensors[window_expr.name] + base_tensor = self.symbolic_tensors[source_buf] in_bounds_constraint = TRUE_CONSTRAINT window_dims = [] - window_idx_iter = iter(window_expr.idx) + window_idx_iter = iter(access_cursor) for dim in base_tensor.dims: if isinstance(dim, SymbolicPoint): window_dims.append(dim) else: - idx = next(window_idx_iter) + idx = next(window_idx_iter)._node if isinstance(idx, LoopIR.Interval): new_dim = SymbolicSlice( self.cm.make_expression(idx.lo).add(dim.lower_bound), @@ -905,6 +905,7 @@ def _lookup_sym(self, sym: Sym) -> ExoValue: def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs]): self._buffer_args = [arg.name for arg in proc.args if arg.type.is_numeric()] + root_cursor = Cursor.create(proc) self._js_lines.append( f'(({",".join(repr(arg) for arg in self._buffer_args)})=>{{' ) @@ -912,9 +913,10 @@ def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArg if coverage_args is not None: self._coverage_state = CoverageState(coverage_args, self) arg_values = [] - for arg in proc.args: + for arg_cursor in root_cursor._child_block("args"): + arg = arg_cursor._node if arg.type.is_numeric(): - if arg.type.is_tensor_or_window(): + if isinstance(arg.type, LoopIR.Tensor): value = Tensor( arg.name, tuple( @@ -936,6 +938,12 @@ def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArg self._coverage_state.make_tensor( arg.name, arg.type.shape(), "true" ) + elif isinstance(arg.type, LoopIR.WindowType): + value = self._transpile_window( + arg.name, + 
arg.type.src_buf, + arg_cursor._child_node("type")._child_block("idx"), + ) else: value = Reference(repr(arg.name), False) if self._coverage_state is not None: @@ -943,7 +951,7 @@ def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArg else: value = Constant(f"${repr(arg.name)}") arg_values.append(value) - self._call_proc(Cursor.create(proc), tuple(arg_values), True) + self._call_proc(root_cursor, tuple(arg_values), True) coverage_object = "" if self._coverage_state is not None: skeleton = self._coverage_state.make_skeleton() @@ -958,6 +966,66 @@ def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArg ) self._js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}" + def _transpile_window( + self, name: Sym, source_buf: Sym, access_cursor: Block + ) -> Tensor: + base = self._name_lookup[source_buf] + assert isinstance(base, Tensor) + window_dims = [] + in_bounds_conds = [] + idx_cursor_iter = iter(access_cursor) + for dim in base.dims: + if isinstance(dim.window_idx, Point): + window_dims.append(dim) + else: + idx_cursor = next(idx_cursor_iter) + idx = idx_cursor._node + if isinstance(idx, LoopIR.Interval): + lo_expr = self._transpile_expr( + idx_cursor._child_node("lo"), + ) + hi_expr = self._transpile_expr( + idx_cursor._child_node("hi"), + ) + lo_sym, hi_sym = Sym("lo"), Sym("hi") + self._js_lines.append( + f"let {repr(lo_sym)}=({lo_expr})+({dim.window_idx.lower_bound});" + ) + self._js_lines.append( + f"let {repr(hi_sym)}=({hi_expr})+({dim.window_idx.lower_bound});" + ) + in_bounds_conds.append( + f"(({dim.window_idx.lower_bound})<=({repr(lo_sym)})&&({repr(lo_sym)})<=({repr(hi_sym)})&&({repr(hi_sym)})<=({dim.window_idx.upper_bound}))" + ) + window_dims.append( + Dimension( + dim.size, dim.stride, Slice(repr(lo_sym), repr(hi_sym)) + ) + ) + elif isinstance(idx, LoopIR.Point): + pt_expr = self._transpile_expr( + idx_cursor._child_node("pt"), + ) + index_sym = Sym("idx") + self._js_lines.append( + f"let {repr(index_sym)}=({pt_expr})+({dim.window_idx.lower_bound});" + ) + in_bounds_conds.append( + f"(({dim.window_idx.lower_bound})<={repr(index_sym)}&&{repr(index_sym)}<({dim.window_idx.upper_bound}))" + ) + window_dims.append( + Dimension(dim.size, dim.stride, Point(repr(index_sym))) + ) + else: + assert False, "not a window index" + in_bounds_js = "&&".join(in_bounds_conds) + if self._coverage_state is not None: + self._coverage_state.assign_window( + name, source_buf, access_cursor, in_bounds_js + ) + self._assert_at_runtime(in_bounds_js) + return Tensor(base.name, tuple(window_dims)) + def _call_proc( self, proc_cursor: Node, arg_values: tuple[ExoValue, ...], top_level: bool ): @@ -1207,60 +1275,9 @@ def _transpile_buffer_arg( self._coverage_state.assign_scalar(new_name, expr.name) return buf elif isinstance(expr, LoopIR.WindowExpr): - base = self._name_lookup[expr.name] - assert isinstance(base, Tensor) - window_dims = [] - in_bounds_conds = [] - idx_cursor_iter = iter(expr_cursor._child_block("idx")) - for dim in base.dims: - if isinstance(dim.window_idx, Point): - window_dims.append(dim) - else: - idx_cursor = next(idx_cursor_iter) - idx = idx_cursor._node - if isinstance(idx, LoopIR.Interval): - lo_expr = self._transpile_expr( - idx_cursor._child_node("lo"), - ) - hi_expr = self._transpile_expr( - idx_cursor._child_node("hi"), - ) - lo_sym, hi_sym = Sym("lo"), Sym("hi") - self._js_lines.append( - f"let {repr(lo_sym)}=({lo_expr})+({dim.window_idx.lower_bound});" - ) - self._js_lines.append( - f"let 
{repr(hi_sym)}=({hi_expr})+({dim.window_idx.lower_bound});" - ) - in_bounds_conds.append( - f"(({dim.window_idx.lower_bound})<=({repr(lo_sym)})&&({repr(lo_sym)})<=({repr(hi_sym)})&&({repr(hi_sym)})<=({dim.window_idx.upper_bound}))" - ) - window_dims.append( - Dimension( - dim.size, dim.stride, Slice(repr(lo_sym), repr(hi_sym)) - ) - ) - elif isinstance(idx, LoopIR.Point): - pt_expr = self._transpile_expr( - idx_cursor._child_node("pt"), - ) - index_sym = Sym("idx") - self._js_lines.append( - f"let {repr(index_sym)}=({pt_expr})+({dim.window_idx.lower_bound});" - ) - in_bounds_conds.append( - f"(({dim.window_idx.lower_bound})<={repr(index_sym)}&&{repr(index_sym)}<({dim.window_idx.upper_bound}))" - ) - window_dims.append( - Dimension(dim.size, dim.stride, Point(repr(index_sym))) - ) - else: - assert False, "not a window index" - in_bounds_js = "&&".join(in_bounds_conds) - if self._coverage_state is not None: - self._coverage_state.assign_window(new_name, expr, in_bounds_js) - self._assert_at_runtime(in_bounds_js) - return Tensor(base.name, tuple(window_dims)) + return self._transpile_window( + new_name, expr.name, expr_cursor._child_block("idx") + ) elif isinstance(expr, LoopIR.ReadConfig): self._configs.add((expr.config, expr.field)) if self._coverage_state is not None: diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index d191c3b3f..75bd80756 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -128,21 +128,6 @@ class ReadWriteSyms: read_syms: set[Sym] -# @dataclass -# class LoopFlattener(LoopIRModifier): -# universal_var_types: dict[Sym, LoopIR.type] = field(default_factory=lambda: {}) -# loop_syms: Optional[ReadWriteSyms] = None - -# def visit(self, node): -# if isinstance(node, LoopIR.For): -# old_loop_syms = self.loop_syms -# new_node = self.visit_generic(node) -# self.loop_syms = old_loop_syms -# elif isinstance(node, LoopIR.Assign): -# elif isinstance(node, LoopIR.Reduce): -# elif isinstance(node, LoopIR.WriteConfig): - - @dataclass class Dimension: size: int @@ -336,7 +321,7 @@ def generate_test_case( arg_values[arg_name] = val for arg_name, arg_type in arg_types.items(): - if arg_type.is_numeric(): + if arg_type.is_numeric() and not isinstance(arg_type, LoopIR.WindowType): if arg_type.is_real_scalar(): shape = (1,) else: @@ -353,7 +338,7 @@ def generate_test_case( class TestResult: buffer_values: dict[Sym, np.ndarray] ctxt_object: dict[str, Union[int, float]] - coverage_result: Optional[dict[str, Union[bool, memoryview, float]]] + coverage_result: Optional[dict[str, Union[bool, memoryview]]] def run_test_case( @@ -472,9 +457,14 @@ def get_test_spec(self, cm: ConstraintMaker) -> TestSpec: self.scope, ).items() ] - args = [arg for arg in args if not arg.type.is_numeric()] + [ - arg for arg in args if arg.type.is_numeric() - ] + args = sorted( + args, + key=lambda arg: ( + (2 if isinstance(arg.type, LoopIR.WindowType) else 1) + if arg.type.is_numeric() + else 0 + ), + ) proc = LoopIR.proc( name=root_proc.name, diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index fd0ae7906..399c1a0f4 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -22,12 +22,14 @@ def solve_proc_assertion(p): cm = ConstraintMaker(p_type.type_map) constraint = cm.make_constraint(p._loopir_proc.preds[0]) return "\n".join( - [ - f"{str(sym)} = {val}" - for sym, val in cm.solve_constraint( - constraint, bound=100, search_limit=10, seed=13 - ).var_assignments.items() - ] + sorted( + [ + f"{str(sym)} = {val}" + for sym, 
val in cm.solve_constraint( + constraint, bound=100, search_limit=10, seed=13 + ).var_assignments.items() + ] + ) ) From 5c8af6b85c9d60667e546285ca0bc37a2dd96ac5 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 13 May 2025 21:55:27 -0400 Subject: [PATCH 15/24] fix linprog behavior and blas test --- requirements.txt | 2 +- setup.cfg | 2 +- src/exo/rewrite/constraint_solver.py | 8 ++++++++ .../test_constraint_solver/test_divmod_solve.txt | 4 ++-- tests/golden/test_constraint_solver/test_solve.txt | 4 ++-- tests/golden/test_constraint_solver/test_xnor.txt | 12 ++++++++++++ tests/test_constraint_solver.py | 9 +++++++++ 7 files changed, 35 insertions(+), 6 deletions(-) create mode 100644 tests/golden/test_constraint_solver/test_xnor.txt diff --git a/requirements.txt b/requirements.txt index d27a4d023..5a054cd8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,6 @@ asdl==0.1.5 build==1.2.2.post1 z3-solver==4.14.0.0 yapf==0.43.0 -scipy==1.6.2 +scipy==1.13.1 hsnf==0.3.16 pythonmonkey==1.1.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9bb57c7d0..fc247a342 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ install_requires = build>=1.2.1 z3-solver>=4.13.0.0 yapf>=0.40.2 - scipy>=1.6.2 + scipy>=1.13.1 hsnf>=0.3.16 pythonmonkey>=1.1.0 diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 164a33e6e..13424626c 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -446,6 +446,13 @@ def make_constraint( expr.lhs ), self.make_constraint(expr.rhs) return lhs_constraints.union(rhs_constraints) + elif expr.op == "==" and isinstance(expr.lhs.type, LoopIR.Bool): + lhs_constraints, rhs_constraints = self.make_constraint( + expr.lhs + ), self.make_constraint(expr.rhs) + return ( + lhs_constraints.invert().intersect(rhs_constraints.invert()) + ).union(lhs_constraints.intersect(rhs_constraints)) else: return self.make_constraint_from_inequality( expr.lhs, expr.rhs, expr.op @@ -578,6 +585,7 @@ def _solve_for_assignments( A_ub=upper_bound_matrix, b_ub=upper_bound_offset, bounds=(None, None), + method="highs", ) if not lp.success: return "infeasible" if len(assignments) == 0 else "failed" diff --git a/tests/golden/test_constraint_solver/test_divmod_solve.txt b/tests/golden/test_constraint_solver/test_divmod_solve.txt index abd52673b..95424a91d 100644 --- a/tests/golden/test_constraint_solver/test_divmod_solve.txt +++ b/tests/golden/test_constraint_solver/test_divmod_solve.txt @@ -1,3 +1,3 @@ -a = 15 +a = 83 b = 2 -c = 25 \ No newline at end of file +c = 26 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt b/tests/golden/test_constraint_solver/test_solve.txt index f91dff186..03a44c85b 100644 --- a/tests/golden/test_constraint_solver/test_solve.txt +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -1,3 +1,3 @@ -a = 22 +a = 93 b = 3 -c = 72 \ No newline at end of file +c = 65 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_xnor.txt b/tests/golden/test_constraint_solver/test_xnor.txt new file mode 100644 index 000000000..efe7ef902 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_xnor.txt @@ -0,0 +1,12 @@ +union( + intersect( + -1 * b_m1 + -1 + 1 * a_m1 + 1 + -1 >= 0, + -1 * a_m1 + -1 + 0 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 1 * b_m1 + 1 + -1 * a_m1 + -1 >= 0, + 1 * a_m1 + 1 + 0 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), +) \ No 
newline at end of file diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index 399c1a0f4..b597ede32 100644 --- a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -87,6 +87,15 @@ def foo(a: size, b: size): assert golden == stringify_proc_constraint(foo) +def test_xnor(golden): + @proc + def foo(a: size, b: size): + assert (a <= b) == (a >= 0) and (a + b < 4) + pass + + assert golden == stringify_proc_constraint(foo) + + def test_inversion(golden): @proc def foo(a: size, b: size): From 3a8dab6bd949320aae8c0b5005ddba9c0a954782 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 14 May 2025 21:30:04 -0400 Subject: [PATCH 16/24] gonna do eval now --- src/exo/API.py | 1 + src/exo/API_scheduling.py | 6 +- src/exo/backend/LoopIR_transpiler.py | 11 +- src/exo/backend/coverage.py | 27 ++ src/exo/core/internal_cursors.py | 2 +- src/exo/rewrite/LoopIR_scheduling.py | 88 ++++-- src/exo/rewrite/chexo.py | 252 ++++++++++++++---- src/exo/rewrite/constraint_solver.py | 46 +++- .../test_variable_length_array_coverage.txt | 2 +- 9 files changed, 336 insertions(+), 99 deletions(-) diff --git a/src/exo/API.py b/src/exo/API.py index d00ff5de3..bd217ada8 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -188,6 +188,7 @@ def _forward(_): ) self._loopir_proc = proc + self._check_mode = "both" self._provenance_eq_Procedure = _provenance_eq_Procedure self._forward = _forward diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index c35bc280a..fcf2ca7b6 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -889,7 +889,7 @@ def reorder_stmts(proc, block_cursor): s1 = block_cursor[0]._impl s2 = block_cursor[1]._impl - ir, fwd = scheduling.DoReorderStmt(s1, s2) + ir, fwd = scheduling.DoReorderStmt(s1, s2, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1216,7 +1216,9 @@ def bind_config(proc, var_cursor, config, field): f"to match type of Config variable ({cfg_f_type})" ) - ir, fwd, cfg = scheduling.DoBindConfig(config, field, var_cursor._impl) + ir, fwd, cfg = scheduling.DoBindConfig( + config, field, var_cursor._impl, proc._check_mode + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd, _mod_config=cfg) diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index 45a1a1daf..e5fe9dd70 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -145,15 +145,17 @@ def access_tensor( js_tensor.name ].resize_placeholder if resize_placeholder is None: - decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer({base_size});" + decl_stmt = f"let {repr(access_sym)}=new Uint8Array({base_size});" fillers = ( IndexedFiller(access_placeholder, mark_stmt), IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), ) else: temp_sym = Sym("temp") - decl_stmt = f"let {repr(access_sym)}=new ArrayBuffer(1,{{maxByteLength:{INITIAL_DYNAMIC_SIZE}}});" - resize_stmt = f"while({base_size}>{repr(access_sym)}.maxByteLength){{let {repr(temp_sym)}=new ArrayBuffer({repr(access_sym)}.byteLength,{{maxByteLength:2*{repr(access_sym)}.maxByteLength}});for(let i=0;i<{repr(access_sym)}.byteLength;i++){repr(temp_sym)}[i]={repr(access_sym)}[i];{repr(access_sym)}={repr(temp_sym)}}};{repr(access_sym)}.resize(Math.max({base_size},{repr(access_sym)}.byteLength));" + decl_stmt = ( + f"let {repr(access_sym)}=new Uint8Array({INITIAL_DYNAMIC_SIZE});" + ) + resize_stmt = f"if({base_size}>{repr(access_sym)}.length){{let {repr(temp_sym)}=new 
Uint8Array(Math.max(2*{repr(access_sym)}.length,{base_size}));for(let i=0;i<{repr(access_sym)}.length;i++){repr(temp_sym)}[i]={repr(access_sym)}[i];{repr(access_sym)}={repr(temp_sym)}}};" fillers = ( IndexedFiller(access_placeholder, mark_stmt), IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), @@ -1207,8 +1209,9 @@ def transpile_loop_body(): tensor_name, stmt.type.shape(), nonnegative_dims_js ) self._assert_at_runtime(nonnegative_dims_js) + buffer_size = f"Math.imul({self.get_size_param_name(tensor_name, 0)},{self.get_stride_param_name(tensor_name, 0)})" self._js_lines.append( - f"let {repr(tensor_name)}=new {buffer_type}(Math.imul({self.get_size_param_name(tensor_name, 0)},{self.get_stride_param_name(tensor_name, 0)}));" + f"let {repr(tensor_name)}=new {buffer_type}({buffer_size});for(let i=0;i<{buffer_size};i++){{{repr(tensor_name)}[i]=(Math.random()-0.5)*(1<<30);}}" ) self._name_lookup[stmt.name] = Tensor( stmt.name, diff --git a/src/exo/backend/coverage.py b/src/exo/backend/coverage.py index de93c13d1..2483c1176 100644 --- a/src/exo/backend/coverage.py +++ b/src/exo/backend/coverage.py @@ -27,6 +27,9 @@ def merge(self, other: "CoverageProgress") -> "CoverageProgress": self.total_cases + other.total_cases, ) + def is_finished(self) -> bool: + return self.covered_cases == self.total_cases + @dataclass class CoverageSolverState: @@ -291,6 +294,30 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: self.visited_nonaliasing = True if new_solution is not None: return state.update_solution(new_constraint, new_solution) + elif not self.visited_nonaliasing and not self.visited_aliasing: + ( + access1_path_constraint, + access1_indices, + _, + ) = self.access1.make_renamed_constraint_and_indices(state) + ( + access2_path_constraint, + access2_indices, + _, + ) = self.access2.make_renamed_constraint_and_indices(state) + path_constraints = access1_path_constraint.intersect( + access2_path_constraint + ) + new_constraint = state.current_constraint.intersect(path_constraints) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if new_solution is None: + self.visited_aliasing = True + self.visited_nonaliasing = True + else: + return state.update_solution(new_constraint, new_solution) + return state diff --git a/src/exo/core/internal_cursors.py b/src/exo/core/internal_cursors.py index 3abae4e29..9f2957790 100644 --- a/src/exo/core/internal_cursors.py +++ b/src/exo/core/internal_cursors.py @@ -583,7 +583,7 @@ def forward(cursor: Node): @dataclass(frozen=True) class NodePath: - path: tuple[tuple[str, Optional[int]]] + path: tuple[tuple[str, Optional[int]], ...] 
@dataclass diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index 877f8da90..5d61fbc09 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -1,7 +1,7 @@ import re from collections import ChainMap import traceback -from typing import Callable, List, Literal, Tuple, Optional +from typing import Any, Callable, List, Literal, Tuple, Optional from ..core.LoopIR import ( LoopIR, @@ -33,7 +33,7 @@ Check_ExprBound, Check_Aliasing, ) -from .chexo import fuzz_reorder_stmts +from .chexo import fuzz, fuzz_reorder_stmts from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis @@ -369,16 +369,20 @@ def divide_expr(e, quot): # Scheduling directives +CheckMode = Literal["static", "dynamic", "both"] + + def do_check( - static_check: Callable[[], None], - dynamic_check: Callable[[], None], - mode: Literal["static", "dynamic", "both"], -): + static_check: Callable[[], Any], + dynamic_check: Callable[[], Any], + mode: CheckMode, +) -> Any: if mode == "both": e_static, e_dynamic = None, None trb_static, trb_dynamic = None, None + static_res = None try: - static_check() + static_res = static_check() except Exception as e: e_static = e trb_static = traceback.format_exc() @@ -393,16 +397,18 @@ def do_check( ), f"fuzzer should match static analysis\ntrb_static: {trb_static}\n\ntrb_dynamic: {trb_dynamic}" elif e_static is not None: raise e_static + else: + return static_res elif mode == "static": - static_check() + return static_check() elif mode == "dynamic": - dynamic_check() + return dynamic_check() # Take a conservative approach and allow stmt reordering only when they are # writing to different buffers # TODO: Do effectcheck's check_commutes-ish thing using SMT here -def DoReorderStmt(f_cursor, s_cursor): +def DoReorderStmt(f_cursor, s_cursor, check_mode: CheckMode): if f_cursor.next() != s_cursor: raise SchedulingError( "expected the second statement to be directly after the first" @@ -410,7 +416,7 @@ def DoReorderStmt(f_cursor, s_cursor): do_check( lambda: Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node), lambda: fuzz_reorder_stmts(f_cursor, s_cursor), - "both", + check_mode, ) ir, fwd = s_cursor._move(f_cursor.before()) return ir, fwd @@ -1199,28 +1205,54 @@ def DoConfigWrite(stmt_cursor, config, field, expr, before=False): # Bind Expression scheduling directive -def DoBindConfig(config, field, expr_cursor): - e = expr_cursor._node - assert isinstance(e, LoopIR.Read) +def DoBindConfig(config, field, expr_cursor, check_mode): + def static_check(): + e = expr_cursor._node + assert isinstance(e, LoopIR.Read) - c = expr_cursor - while not isinstance(c._node, LoopIR.stmt): - c = c.parent() + c = expr_cursor + while not isinstance(c._node, LoopIR.stmt): + c = c.parent() - cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) - ir, fwd = c.before()._insert([cfg_write_s]) + cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) + ir, fwd = c.before()._insert([cfg_write_s]) - mod_cfg = Check_DeleteConfigWrite(ir, [cfg_write_s]) + mod_cfg = Check_DeleteConfigWrite(ir, [cfg_write_s]) - cfg_f_type = config.lookup_type(field) - cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) - if isinstance(expr_cursor.parent()._node, LoopIR.Call): - cfg_read_e = [cfg_read_e] - ir, fwd_repl = fwd(expr_cursor)._replace(cfg_read_e) - fwd = _compose(fwd_repl, fwd) + cfg_f_type = config.lookup_type(field) + cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) 
+ if isinstance(expr_cursor.parent()._node, LoopIR.Call): + cfg_read_e = [cfg_read_e] + ir, fwd_repl = fwd(expr_cursor)._replace(cfg_read_e) + fwd = _compose(fwd_repl, fwd) - Check_Aliasing(ir) - return ir, fwd, mod_cfg + Check_Aliasing(ir) + return ir, fwd, mod_cfg + + def dynamic_check(): + e = expr_cursor._node + assert isinstance(e, LoopIR.Read) + + c = expr_cursor + while not isinstance(c._node, LoopIR.stmt): + c = c.parent() + + cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) + ir, fwd1 = c.before()._insert([LoopIR.Pass(e.srcinfo)]) + pass_cursor = fwd1(c).prev() + ir, fwd2 = pass_cursor._replace([cfg_write_s]) + new_expr_cursor = fwd2(fwd1(expr_cursor)) + + cfg_f_type = config.lookup_type(field) + cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) + if isinstance(expr_cursor.parent()._node, LoopIR.Call): + cfg_read_e = [cfg_read_e] + ir, fwd3 = new_expr_cursor._replace(cfg_read_e) + fwd = _compose(fwd3, _compose(fwd2, fwd1)) + fuzz(pass_cursor.as_block().expand(delta_hi=1), _compose(fwd3, fwd2)) + return ir, fwd, None + + return do_check(static_check, dynamic_check, check_mode) def DoCommuteExpr(expr_cursors): diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index 75bd80756..f0cadc0f9 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -1,9 +1,9 @@ from itertools import chain from typing import Callable, Literal, Optional, Union -from ..core.internal_cursors import Cursor, Block, Node +from ..core.internal_cursors import Cursor, Block, Node, NodePath -from ..backend.LoopIR_transpiler import CoverageArgs, Transpiler +from ..backend.LoopIR_transpiler import CoverageArgs, StageMemArgs, Transpiler from ..backend.coverage import CoverageSkeleton from ..core.configs import Config @@ -413,21 +413,65 @@ class TestSpec: proc: LoopIR.proc constraint: DisjointConstraint arg_types: dict[Sym, LoopIR.type] + original_scope: Block + + def forward_to_test(self, cursor: Block) -> Optional[Block]: + if cursor in self.original_scope: + return Block( + self.proc, + Node(self.proc, []), + "body", + range( + cursor._range.start - self.original_scope._range.start, + cursor._range.stop - self.original_scope._range.stop, + ), + ) + for node_idx, node in enumerate(self.original_scope): + if node.is_ancestor_of(cursor): + return Block( + self.proc, + Node( + self.proc, + [("body", node_idx)] + cursor._anchor._path[len(node._path) :], + ), + cursor._attr, + cursor._range, + ) + return None + + def backward_from_test(self, path: NodePath) -> NodePath: + assert path.path[0][1] is not None + return NodePath( + tuple(self.original_scope._anchor._path) + + ( + ( + self.original_scope._attr, + self.original_scope._range.start + path.path[0][1], + ), + ) + + path.path[1:] + ) @dataclass class TestScope: - scope: Union[Block, Node] - flatten_loops: bool + scope: Block def broaden(self) -> Optional["TestScope"]: - if self.scope.depth() == 0: - return TestScope(self.scope, False) if self.flatten_loops else None + if self.scope._anchor.depth() == 0: + new_scope = self.scope.expand() + if ( + new_scope._range.start == self.scope._range.start + and new_scope._range.stop == self.scope._range.stop + ): + return None + else: + return TestScope(new_scope) else: - return TestScope(self.scope.parent(), self.flatten_loops) + return TestScope(self.scope._anchor.as_block()) def transform(self, forward: Callable[[Cursor], Cursor]) -> "TestScope": - return TestScope(forward(self.scope), self.flatten_loops) + return TestScope(forward(self.scope)) def 
get_type_map(self) -> dict[Sym, LoopIR.type]: root_proc = self.scope.get_root() @@ -470,75 +514,165 @@ def get_test_spec(self, cm: ConstraintMaker) -> TestSpec: name=root_proc.name, args=args, preds=[], - body=( - [self.scope._node] - if isinstance(self.scope, Node) - else self.scope.resolve_all() - ), + body=(self.scope.resolve_all()), instr=None, srcinfo=root_proc.srcinfo, ) arg_types = {arg.name: arg.type for arg in args} - return TestSpec(proc, constraint, arg_types) + return TestSpec(proc, constraint, arg_types, self.scope) TEST_CASE_BOUND = 15 def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): - cur_scope = TestScope(starting_scope, True) - transformed = cur_scope.transform(fwd) - - cm = ConstraintMaker(cur_scope.get_type_map() | transformed.get_type_map()) - - spec1 = cur_scope.get_test_spec(cm) - spec2 = transformed.get_test_spec(cm) - - failure_scope = ( + starting_scope = ( starting_scope.as_block() if isinstance(starting_scope, Node) else starting_scope ) - transpiled_test1 = Transpiler(spec1.proc, CoverageArgs(cm, failure_scope)) - transpiled_test2 = Transpiler(spec2.proc, CoverageArgs(cm, failure_scope)) + failure_scope = starting_scope + failure_transformed_scope = fwd(failure_scope) + assert isinstance(failure_transformed_scope, Block) + cur_scope = TestScope(starting_scope) + + while cur_scope is not None: + transformed = cur_scope.transform(fwd) + cm = ConstraintMaker(cur_scope.get_type_map() | transformed.get_type_map()) + + spec1 = cur_scope.get_test_spec(cm) + spec2 = transformed.get_test_spec(cm) + + transpiled_test1 = Transpiler( + spec1.proc, CoverageArgs(cm, spec1.forward_to_test(failure_scope)) + ) + transpiled_test2 = Transpiler( + spec2.proc, + CoverageArgs(cm, spec2.forward_to_test(failure_transformed_scope)), + ) - config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() + config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() - arg_types = spec1.arg_types | spec2.arg_types - constraint = spec1.constraint.union(spec2.constraint) - skeleton1, skeleton2 = ( - transpiled_test1.get_coverage_skeleton(), - transpiled_test2.get_coverage_skeleton(), + arg_types = spec1.arg_types | spec2.arg_types + constraint = spec1.constraint.union(spec2.constraint) + skeleton1, skeleton2 = ( + transpiled_test1.get_coverage_skeleton(), + transpiled_test2.get_coverage_skeleton(), + ) + assert skeleton1 is not None and skeleton2 is not None + coverage_skeleton = skeleton1.merge(skeleton2) + tests_passed = True + while not coverage_skeleton.get_coverage_progress().is_finished(): + test_case = generate_test_case( + arg_types, + config_fields, + constraint, + coverage_skeleton, + cm, + ) + if test_case is None: + continue + + out1 = run_test_case(test_case, transpiled_test1) + out2 = run_test_case(test_case, transpiled_test2) + if out1 == "failed" or out2 == "failed": + tests_passed = False + break + assert out1.coverage_result is not None and out2.coverage_result is not None + coverage_skeleton.update_coverage( + out1.coverage_result | out2.coverage_result + ) + for buffer_name in out1.buffer_values.keys() & out2.buffer_values.keys(): + if not np.allclose( + out1.buffer_values[buffer_name], out2.buffer_values[buffer_name] + ): + tests_passed = False + break + if cur_scope.broaden() is not None: + for ctxt_name in out1.ctxt_object & out2.ctxt_object.keys(): + if not np.allclose( + out1.ctxt_object[ctxt_name], out2.ctxt_object[ctxt_name] + ): + tests_passed = False + break + if tests_passed: + return + else: + 
cur_scope = cur_scope.broaden() + raise SchedulingError("tests failed at broadest scope") + + +def fuzz_stage_mem( + starting_scope: Block, window_expr: LoopIR.WindowExpr +) -> set[NodePath]: + starting_scope = ( + starting_scope.as_block() + if isinstance(starting_scope, Node) + else starting_scope ) - assert skeleton1 is not None and skeleton2 is not None - coverage_skeleton = skeleton1.merge(skeleton2) - for _ in range(TEST_CASE_BOUND): - test_case = generate_test_case( - arg_types, - config_fields, - constraint, - coverage_skeleton, - cm, + failure_scope = starting_scope + stage_scope = failure_scope + cur_scope = TestScope(starting_scope) + + while cur_scope is not None: + cm = ConstraintMaker(cur_scope.get_type_map()) + + spec = cur_scope.get_test_spec(cm) + + forwarded_stage_scope = spec.forward_to_test(stage_scope) + assert forwarded_stage_scope is not None + transpiled_test = Transpiler( + spec.proc, + CoverageArgs( + cm, + spec.forward_to_test(failure_scope), + StageMemArgs(window_expr, forwarded_stage_scope), + ), ) - if test_case is None: - continue - - out1 = run_test_case(test_case, transpiled_test1) - out2 = run_test_case(test_case, transpiled_test2) - if out1 == "failed" or out2 == "failed": - raise SchedulingError("domain mismatch") - assert out1.coverage_result is not None and out2.coverage_result is not None - coverage_skeleton.update_coverage(out1.coverage_result | out2.coverage_result) - for buffer_name in out1.buffer_values.keys() & out2.buffer_values.keys(): - if not np.allclose( - out1.buffer_values[buffer_name], out2.buffer_values[buffer_name] - ): - raise SchedulingError("mismatch found") - for ctxt_name in out1.ctxt_object & out2.ctxt_object.keys(): - if not np.allclose( - out1.ctxt_object[ctxt_name], out2.ctxt_object[ctxt_name] - ): - raise SchedulingError("context mismatch found") + + config_fields = transpiled_test.get_configs() + + arg_types = spec.arg_types + constraint = spec.constraint + coverage_skeleton = transpiled_test.get_coverage_skeleton() + assert coverage_skeleton is not None + tests_passed = True + while not coverage_skeleton.get_coverage_progress().is_finished(): + test_case = generate_test_case( + arg_types, + config_fields, + constraint, + coverage_skeleton, + cm, + ) + if test_case is None: + continue + + out = run_test_case(test_case, transpiled_test) + if out == "failed": + tests_passed = False + break + assert out.coverage_result is not None + coverage_skeleton.update_coverage(out.coverage_result) + + if tests_passed: + overlapping_paths = set() + failed = False + for staging_overlap in coverage_skeleton.staging_overlaps: + if staging_overlap.has_disjoint_access and staging_overlap.has_overlap: + failed = True + break + elif staging_overlap.has_overlap: + overlapping_paths.add( + spec.backward_from_test(staging_overlap.access_cursor) + ) + if failed: + cur_scope = cur_scope.broaden() + else: + return overlapping_paths + else: + cur_scope = cur_scope.broaden() + raise SchedulingError("cannot stage due to window overlaps") def fuzz_reorder_stmts(s1: Node, s2: Node): diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 13424626c..a8f461163 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -465,6 +465,22 @@ def make_constraint( ).lift_to_disjoint_constraint() elif isinstance(expr, LoopIR.Const): return TRUE_CONSTRAINT if expr.val else FALSE_CONSTRAINT + + elif isinstance(expr, LoopIR.ReadConfig): + if (expr.config, expr.field) not in self.ctxt: + field_type 
= expr.config.lookup_type(expr.field) + assert isinstance(field_type, LoopIR.Bool) + var_sub_result = self.make_var_sub( + f"{expr.config.name()}_{expr.field}", field_type + ) + assert ( + var_sub_result is not None + ), "constraints can only occur on control variables" + self.ctxt[(expr.config, expr.field)] = var_sub_result + return Constraint( + self.ctxt[(expr.config, expr.field)].add(Expression.from_constant(-1)), + False, + ).lift_to_disjoint_constraint() else: assert False, "only boolean expected" @@ -580,16 +596,38 @@ def _solve_for_assignments( (np.ones(m_nonslack) * bound - vec_f[:m_nonslack], vec_f), axis=0, ) + radius_row = np.zeros((1, m - k + 1)) + radius_row[0, -1] = -1 + upper_bound_matrix_with_radius = np.concatenate( + ( + np.concatenate( + ( + upper_bound_matrix, + np.linalg.norm(upper_bound_matrix, axis=1)[ + :, np.newaxis + ], + ), + axis=1, + ), + radius_row, + ), + axis=0, + ) + upper_bound_offset_with_radius = np.concatenate( + (upper_bound_offset, np.array([0])), axis=0 + ) + objective = np.zeros(m - k + 1) + objective[-1] = -1 lp = linprog( - np.zeros(m - k), - A_ub=upper_bound_matrix, - b_ub=upper_bound_offset, + objective, + A_ub=upper_bound_matrix_with_radius, + b_ub=upper_bound_offset_with_radius, bounds=(None, None), method="highs", ) if not lp.success: return "infeasible" if len(assignments) == 0 else "failed" - cur_y = lp.x + cur_y = lp.x[: m - k] har_iter = 50 last_int_y = None for _ in range(har_iter): diff --git a/tests/golden/test_transpiler/test_variable_length_array_coverage.txt b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt index 54dba4997..7788ca7c3 100644 --- a/tests/golden/test_transpiler/test_variable_length_array_coverage.txt +++ b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt @@ -8,7 +8,7 @@ let size_b_3_0=i_2; let stride_b_3_0=1; while(Math.imul(size_b_3_0,stride_b_3_0)>access_11.maxByteLength){let temp_12=new ArrayBuffer(access_11.byteLength,{maxByteLength:2*access_11.maxByteLength});for(let i=0;iaccess_13.maxByteLength){let temp_14=new ArrayBuffer(access_13.byteLength,{maxByteLength:2*access_13.maxByteLength});for(let i=0;i=0))return [1,ctxt,{}]; -let b_3=new Int32Array(Math.imul(size_b_3_0,stride_b_3_0)); +let b_3=new Int32Array(Math.imul(size_b_3_0,stride_b_3_0));for(let i=0;i=(0)&&((((i_2-1))+(0)))<(size_b_3_0)))return [1,ctxt,{}]; b_3[Math.imul(stride_b_3_0,(((i_2-1))+(0)))]=0; From 138486ca50956afbe7a016502701b2bcd97fd724 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Thu, 15 May 2025 06:11:38 -0400 Subject: [PATCH 17/24] i lied --- src/exo/API.py | 37 ++- src/exo/__init__.py | 2 + src/exo/backend/LoopIR_compiler.py | 52 ++- src/exo/backend/LoopIR_transpiler.py | 36 ++- src/exo/core/extern.py | 3 + src/exo/frontend/typecheck.py | 99 ++++-- src/exo/libs/externs.py | 118 ++++++- src/exo/rewrite/LoopIR_scheduling.py | 5 +- src/exo/rewrite/chexo.py | 173 +++++----- src/exo/rewrite/constraint_solver.py | 461 ++++++++++++++++++++------- 10 files changed, 723 insertions(+), 263 deletions(-) diff --git a/src/exo/API.py b/src/exo/API.py index bd217ada8..84f195993 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -22,17 +22,18 @@ # Moved to new file from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc from .frontend.pyparser import get_ast_from_python, Parser, get_parent_scope -from .frontend.typecheck import TypeChecker +from .frontend.typecheck import TypeChecker, CheckMode from . 
import API_cursors as C from .core import internal_cursors as IC +from .backend.LoopIR_compiler import DEFAULT_CHECK_MODE # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Top-level decorator -def proc(f, _instr=None) -> "Procedure": +def proc(f, _instr=None, _check_mode: Optional[CheckMode] = None) -> "Procedure": if not isinstance(f, types.FunctionType): raise TypeError("@proc decorator must be applied to a function") @@ -42,11 +43,11 @@ def proc(f, _instr=None) -> "Procedure": parser = Parser( body, src_info, - parent_scope=get_parent_scope(depth=3 if _instr else 2), + parent_scope=get_parent_scope(depth=3 if _instr or _check_mode else 2), instr=_instr, as_func=True, ) - return Procedure(parser.result()) + return Procedure(parser.result(), _check_mode=_check_mode) def instr(c_instr, c_global=""): @@ -84,6 +85,14 @@ def parse_config(cls): return parse_config(_cls) +def chexo(f): + return proc(f, _check_mode="dynamic") + + +def chexo_debug(f): + return proc(f, _check_mode="both") + + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # iPython Display Object @@ -150,7 +159,11 @@ def compile_procs(proc_list, basedir: Path, c_file: str, h_file: str): def compile_procs_to_strings(proc_list, h_file_name: str): assert isinstance(proc_list, list) assert all(isinstance(p, Procedure) for p in proc_list) - return run_compile([p._loopir_proc for p in proc_list], h_file_name) + return run_compile( + [p._loopir_proc for p in proc_list], + h_file_name, + "dynamic" if any(p._check_mode == "dynamic" for p in proc_list) else "static", + ) class Procedure(ProcedureBase): @@ -160,15 +173,18 @@ def __init__( _provenance_eq_Procedure: "Procedure" = None, _forward=None, _mod_config=None, + _check_mode: Optional[CheckMode] = None, ): super().__init__() _mod_config = _mod_config or frozenset() + self._check_mode = DEFAULT_CHECK_MODE if _check_mode is None else _check_mode if isinstance(proc, LoopIR.UAST.proc): - proc = TypeChecker(proc).get_loopir() - CheckBounds(proc) - Check_Aliasing(proc) + proc = TypeChecker(proc, self._check_mode).get_loopir() + if self._check_mode != "dynamic": + CheckBounds(proc) + Check_Aliasing(proc) assert isinstance(proc, LoopIR.LoopIR.proc) @@ -188,7 +204,6 @@ def _forward(_): ) self._loopir_proc = proc - self._check_mode = "both" self._provenance_eq_Procedure = _provenance_eq_Procedure self._forward = _forward @@ -295,7 +310,9 @@ def find_all(self, pattern): # ---------------------------------------------- # def c_code_str(self): - decls, defns = compile_to_strings("c_code_str", [self._loopir_proc]) + decls, defns = compile_to_strings( + "c_code_str", [self._loopir_proc], check_mode=self._check_mode + ) return decls + "\n" + defns def compile_c(self, directory: Path, filename: str): diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 95fe0c050..d44c7aaa3 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -5,6 +5,8 @@ proc, instr, config, + chexo, + chexo_debug, ExoType, ) from .rewrite.LoopIR_scheduling import SchedulingError diff --git a/src/exo/backend/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py index 090a215de..72586a8f2 100644 --- a/src/exo/backend/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -16,6 +16,8 @@ from .win_analysis import WindowAnalysis from ..rewrite.range_analysis import IndexRangeEnvironment 
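+# Fallback check mode applied when a Procedure or compile call does not specify one.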
+DEFAULT_CHECK_MODE = "both" + def sanitize_str(s): return re.sub(r"\W", "_", s) @@ -52,7 +54,11 @@ def sanitize_str(s): def lift_to_cir(e, range_env): assert e.type.is_indexable(), "why are you here?" - is_non_neg = lambda e: range_env.check_expr_bound(0, IndexRangeEnvironment.leq, e) + is_non_neg = lambda e: ( + False + if range_env is None + else range_env.check_expr_bound(0, IndexRangeEnvironment.leq, e) + ) if isinstance(e, LoopIR.Read): return CIR.Read(e.name, is_non_neg(e)) @@ -320,10 +326,10 @@ def window_struct(base_type, n_dims, is_const) -> WindowStruct: # top level compiler function called by tests! -def run_compile(proc_list, h_file_name: str): +def run_compile(proc_list, h_file_name: str, check_mode=None): file_stem = str(Path(h_file_name).stem) lib_name = sanitize_str(file_stem) - fwd_decls, body = compile_to_strings(lib_name, proc_list) + fwd_decls, body = compile_to_strings(lib_name, proc_list, check_mode) source = f'#include "{h_file_name}"\n\n{body}' @@ -360,7 +366,8 @@ def run_compile(proc_list, h_file_name: str): } -def compile_to_strings(lib_name, proc_list): +def compile_to_strings(lib_name, proc_list, check_mode=None): + check_mode = DEFAULT_CHECK_MODE if check_mode is None else check_mode # Get transitive closure of call-graph orig_procs = [id(p) for p in proc_list] @@ -407,12 +414,15 @@ def from_lines(x): else: is_public_decl = id(p) in orig_procs - p = ParallelAnalysis().run(p) + if check_mode != "dynamic": + p = ParallelAnalysis().run(p) p = PrecisionAnalysis().run(p) p = WindowAnalysis().apply_proc(p) p = MemoryAnalysis().run(p) - comp = Compiler(p, ctxt_name, is_public_decl=is_public_decl) + comp = Compiler( + p, ctxt_name, is_public_decl=is_public_decl, check_mode=check_mode + ) d, b = comp.comp_top() struct_defns |= comp.struct_defns() needed_helpers |= comp.needed_helpers() @@ -522,13 +532,15 @@ def _compile_context_struct(configs, lib_name): class Compiler: - def __init__(self, proc, ctxt_name, *, is_public_decl): + def __init__(self, proc, ctxt_name, *, is_public_decl, check_mode): assert isinstance(proc, LoopIR.proc) self.proc = proc self.ctxt_name = ctxt_name self.env = ChainMap() - self.range_env = IndexRangeEnvironment(proc, fast=False) + self.range_env = ( + None if check_mode == "dynamic" else IndexRangeEnvironment(proc, fast=False) + ) self.names = ChainMap() self.envtyp = dict() self.mems = dict() @@ -693,12 +705,14 @@ def new_varname(self, symbol, typ, mem=None): def push(self, only=None): if only is None: self.env = self.env.new_child() - self.range_env.enter_scope() + if self.range_env is not None: + self.range_env.enter_scope() self.names = self.names.new_child() self._tab = self._tab + " " elif only == "env": self.env = self.env.new_child() - self.range_env.enter_scope() + if self.range_env is not None: + self.range_env.enter_scope() self.names = self.names.new_child() elif only == "tab": self._tab = self._tab + " " @@ -707,7 +721,8 @@ def push(self, only=None): def pop(self): self.env = self.env.parents - self.range_env.exit_scope() + if self.range_env is not None: + self.range_env.exit_scope() self.names = self.names.parents self._tab = self._tab[:-2] @@ -894,11 +909,12 @@ def comp_s(self, s): hi = self.comp_e(s.hi) self.push(only="env") itr = self.new_varname(s.iter, typ=T.index) # allocate a new string - self.range_env.add_loop_iter( - s.iter, - s.lo, - s.hi, - ) + if self.range_env is not None: + self.range_env.add_loop_iter( + s.iter, + s.lo, + s.hi, + ) if isinstance(s.loop_mode, LoopIR.Par): self.add_line(f"#pragma omp parallel for") 
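+        # Emit the C loop header; lo/hi were already lowered to C expressions via comp_e.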
self.add_line(f"for (int_fast32_t {itr} = {lo}; {itr} < {hi}; {itr}++) {{") @@ -1035,7 +1051,9 @@ def comp_e(self, e, prec=0): rhs = self.comp_e(e.rhs, local_prec + 1) if int_div: - if self.range_env.check_expr_bound(0, IndexRangeEnvironment.leq, e): + if self.range_env is None or self.range_env.check_expr_bound( + 0, IndexRangeEnvironment.leq, e + ): # TODO: too many parens? return f"(({lhs}) / ({rhs}))" return self._call_static_helper("exo_floor_div", lhs, rhs) diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index e5fe9dd70..c145ccd62 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -477,6 +477,7 @@ def make_parallel_access_pairs(self) -> tuple[ParallelAccessPair, ...]: @dataclass class CoverageArgs: cm: ConstraintMaker + var_renaming: dict[Sym, Sym] failure_scope: Optional[Block] = None stage_mem_args: Optional[StageMemArgs] = None @@ -484,6 +485,7 @@ class CoverageArgs: class CoverageState: def __init__(self, args: CoverageArgs, parent_transpiler: "Transpiler"): self.cm: ConstraintMaker = args.cm + self.var_renaming: dict[Sym, Sym] = args.var_renaming self.parent_transpiler: Transpiler = parent_transpiler self.cov_placeholder: int = parent_transpiler._make_placeholder() self.root: CoverageSkeletonNode = CoverageSkeletonNode(None, None, ()) @@ -516,16 +518,18 @@ def enter_loop( skip_sym = Sym("skip") loop_entrance_placeholder = self.parent_transpiler._make_placeholder() body_constraint = ( - self.cm.make_constraint_from_inequality(stmt.lo, stmt.iter, "<=") + self.cm.make_constraint_from_inequality( + stmt.lo, stmt.iter, "<=", self.var_renaming + ) .lift_to_disjoint_constraint() .intersect( self.cm.make_constraint_from_inequality( - stmt.iter, stmt.hi, "<" + stmt.iter, stmt.hi, "<", self.var_renaming ).lift_to_disjoint_constraint() ) ) skip_constraint = self.cm.make_constraint_from_inequality( - stmt.lo, stmt.hi, ">=" + stmt.lo, stmt.hi, ">=", self.var_renaming ).lift_to_disjoint_constraint() parent_node = self.current_node body_child = CoverageSkeletonNode( @@ -576,7 +580,7 @@ def enter_if( true_sym = Sym("true_case") false_sym = Sym("false_case") true_placeholder = self.parent_transpiler._make_placeholder() - cond_constraint = self.cm.make_constraint(stmt.cond) + cond_constraint = self.cm.make_constraint(stmt.cond, self.var_renaming) true_node = CoverageSkeletonNode( true_sym, (parent_node, cond_constraint), @@ -642,7 +646,7 @@ def assert_shape_matches( ): match_cond = match_cond.intersect( Constraint( - self.cm.make_expression(shape_dim) + self.cm.make_expression(shape_dim, self.var_renaming) .negate() .add(tensor_dim.upper_bound) .add(tensor_dim.lower_bound.negate()), @@ -657,13 +661,15 @@ def assert_shape_matches( def assert_predicate(self, pred: LoopIR.expr, js_pred: str): if self.failure_tracker is not None: self.failure_tracker.add_assertion( - self.cm.make_constraint(pred), + self.cm.make_constraint(pred, self.var_renaming), js_pred, self.parent_transpiler._make_placeholder(), ) def make_tensor(self, sym: Sym, dims: list[LoopIR.expr], nonnegative_dims_js: str): - symbolic_dims = tuple(self.cm.make_expression(dim) for dim in dims) + symbolic_dims = tuple( + self.cm.make_expression(dim, self.var_renaming) for dim in dims + ) nonnegative_constraint = TRUE_CONSTRAINT for symbolic_dim in symbolic_dims: nonnegative_constraint = nonnegative_constraint.intersect( @@ -713,8 +719,12 @@ def assign_window( idx = next(window_idx_iter)._node if isinstance(idx, LoopIR.Interval): new_dim = SymbolicSlice( - 
self.cm.make_expression(idx.lo).add(dim.lower_bound), - self.cm.make_expression(idx.hi).add(dim.lower_bound), + self.cm.make_expression(idx.lo, self.var_renaming).add( + dim.lower_bound + ), + self.cm.make_expression(idx.hi, self.var_renaming).add( + dim.lower_bound + ), ) in_bounds_constraint = in_bounds_constraint.intersect( Constraint( @@ -728,7 +738,9 @@ def assign_window( window_dims.append(new_dim) else: new_dim = SymbolicPoint( - self.cm.make_expression(idx.pt).add(dim.lower_bound) + self.cm.make_expression(idx.pt, self.var_renaming).add( + dim.lower_bound + ) ) in_bounds_constraint = in_bounds_constraint.intersect( Constraint( @@ -771,7 +783,9 @@ def access_tensor( in_bounds_constraint = TRUE_CONSTRAINT for dim in symbolic_tensor.dims: if isinstance(dim, SymbolicSlice): - idx = self.cm.make_expression(next(idx_expr_iter)).add(dim.lower_bound) + idx = self.cm.make_expression( + next(idx_expr_iter), self.var_renaming + ).add(dim.lower_bound) in_bounds_constraint = in_bounds_constraint.intersect( Constraint( idx.negate().add(dim.upper_bound), True diff --git a/src/exo/core/extern.py b/src/exo/core/extern.py index 979689f6d..d4322c8e3 100644 --- a/src/exo/core/extern.py +++ b/src/exo/core/extern.py @@ -37,3 +37,6 @@ def transpile(self, args): def compile(self, args, prim_type): raise NotImplementedError() + + def express_in_constraints(self, args, out_sym): + raise NotImplementedError() diff --git a/src/exo/frontend/typecheck.py b/src/exo/frontend/typecheck.py index 883c5b25b..6bbfa8bd4 100644 --- a/src/exo/frontend/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -1,3 +1,4 @@ +from typing import Literal from ..core.LoopIR import ( T, UAST, @@ -10,7 +11,6 @@ from ..core.extern import Extern_Typecheck_Error from ..core.memory import * - # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -35,6 +35,10 @@ # The typechecker +def is_mutable_index(t: LoopIR.type): + return isinstance(t, (T.INT8, T.UINT8, T.UINT16, T.INT32)) + + def check_call_types(err_handler, args, call_args): for call_a, sig_a in zip(args, call_args): if call_a.type == T.err: @@ -46,6 +50,7 @@ def check_call_types(err_handler, args, call_args): "expected size or index type " "expression, " f"but got type {call_a.type}", + allowed_in_chexo=is_mutable_index(call_a.type), ) elif sig_a.type is T.bool: @@ -92,12 +97,15 @@ def check_call_types(err_handler, args, call_args): assert False, "bad argument type case" +CheckMode = Literal["static", "dynamic", "both"] + + class TypeChecker: - def __init__(self, proc): + def __init__(self, proc, check_mode: CheckMode): self.uast_proc = proc self.env = dict() self.errors = [] - self.must_fuzz_reason = None + self.check_mode = check_mode args = [] for a in proc.args: @@ -151,8 +159,9 @@ def __init__(self, proc): def get_loopir(self): return self.loopir_proc - def err(self, node, msg): - self.errors.append(f"{node.srcinfo}: {msg}") + def err(self, node, msg, *, allowed_in_chexo=False): + if not allowed_in_chexo or self.check_mode != "dynamic": + self.errors.append(f"{node.srcinfo}: {msg}") def check_stmts(self, body): assert len(body) > 0 or self.uast_proc.instr @@ -166,7 +175,11 @@ def check_access(self, node, nm, idx, lvalue=False): idx = [self.check_e(i, is_index=True) for i in idx] for i in idx: if i.type != T.err and not i.type.is_indexable(): - self.err(i, f"cannot index with expression of type '{i.type}'") + self.err( + i, + f"cannot index with expression of type 
'{i.type}'", + allowed_in_chexo=is_mutable_index(i.type), + ) # check compatibility with buffer type typ = self.env[nm] @@ -237,7 +250,11 @@ def check_single_stmt(self, stmt): if isinstance(stmt, (UAST.Assign, UAST.Reduce)): rhs = self.check_e(stmt.rhs) if rhs.type != T.err and not rhs.type.is_real_scalar(): - self.err(rhs, f"cannot assign/reduce a '{rhs.type}' type value") + self.err( + rhs, + f"cannot assign/reduce a '{rhs.type}' type value", + allowed_in_chexo=rhs.type.is_indexable(), + ) idx, typ = self.check_access(stmt, stmt.name, stmt.idx, lvalue=True) assert typ.is_real_scalar() or typ is T.err @@ -267,6 +284,8 @@ def check_single_stmt(self, stmt): rhs, f"expected a real scalar value, but " f"got an expression of type {rhs.type}", + allowed_in_chexo=is_mutable_index(ftyp) + and rhs.type.is_indexable(), ) elif ftyp.is_indexable(): if not rhs.type.is_indexable(): @@ -274,6 +293,7 @@ def check_single_stmt(self, stmt): rhs, f"expected an index or size type " f"expression, but got type {rhs.type}", + allowed_in_chexo=is_mutable_index(rhs.type), ) elif ftyp == T.bool: if rhs.type != T.bool: @@ -319,10 +339,18 @@ def check_single_stmt(self, stmt): lo = self.check_e(stmt.cond.lo, is_index=True) if lo.type != T.err and not lo.type.is_indexable(): - self.err(lo, "expected loop bound to be indexable.") + self.err( + lo, + "expected loop bound to be indexable.", + allowed_in_chexo=is_mutable_index(lo.type), + ) hi = self.check_e(stmt.cond.hi, is_index=True) if hi.type != T.err and not hi.type.is_indexable(): - self.err(hi, "expected loop bound to be indexable.") + self.err( + hi, + "expected loop bound to be indexable.", + allowed_in_chexo=is_mutable_index(hi.type), + ) body = self.check_stmts(stmt.body) if isinstance(stmt.cond, UAST.SeqRange): @@ -358,7 +386,11 @@ def check_w_access(self, e, orig_hi): if isinstance(e, UAST.Point): pt = self.check_e(e.pt, is_index=True) if pt.type != T.err and not pt.type.is_indexable(): - self.err(pt, f"cannot index with expression of type '{pt.type}'") + self.err( + pt, + f"cannot index with expression of type '{pt.type}'", + allowed_in_chexo=is_mutable_index(pt.type), + ) return LoopIR.Point(pt, e.srcinfo) elif isinstance(e, UAST.Interval): @@ -367,14 +399,22 @@ def check_w_access(self, e, orig_hi): else: lo = self.check_e(e.lo, is_index=True) if lo.type != T.err and not lo.type.is_indexable(): - self.err(lo, f"cannot index with expression of type '{lo.type}'") + self.err( + lo, + f"cannot index with expression of type '{lo.type}'", + allowed_in_chexo=is_mutable_index(lo.type), + ) if e.hi is None: hi = orig_hi else: hi = self.check_e(e.hi, is_index=True) if hi.type != T.err and not hi.type.is_indexable(): - self.err(hi, f"cannot index with expression of type '{hi.type}'") + self.err( + hi, + f"cannot index with expression of type '{hi.type}'", + allowed_in_chexo=is_mutable_index(hi.type), + ) return LoopIR.Interval(lo, hi, e.srcinfo) @@ -473,20 +513,33 @@ def check_e(self, e, is_index=False): operand, f"expected 'index' or 'size' argument to " f"comparison op: {e.op}", + allowed_in_chexo=operand.type.is_real_scalar(), ) typ = T.bool elif e.op in ("+", "-", "*", "/", "%"): if lhs.type.is_real_scalar(): if not rhs.type.is_real_scalar(): - self.err(rhs, "expected scalar type") - typ = T.err + self.err( + rhs, + "expected scalar type", + allowed_in_chexo=rhs.type.is_indexable(), + ) elif e.op == "%": - self.err(e, "cannot compute modulus of 'R' values") - typ = T.err - else: - typ = lhs.type + self.err( + e, + "cannot compute modulus of 'R' values", + 
allowed_in_chexo=is_mutable_index(lhs.type) + and is_mutable_index(rhs.type), + ) + typ = lhs.type elif rhs.type.is_real_scalar(): - self.err(lhs, "expected scalar type") + self.err( + lhs, + "expected scalar type", + allowed_in_chexo=is_mutable_index(lhs.type) + and is_mutable_index(rhs.type), + ) + typ = lhs.type elif lhs.type == T.bool or rhs.type == T.bool: node = lhs if lhs.type == T.bool else rhs self.err(node, "cannot perform arithmetic on 'bool' values") @@ -506,15 +559,15 @@ def check_e(self, e, is_index=False): self.err( rhs, "cannot divide or modulo by a " "non-constant value", + allowed_in_chexo=True, ) - typ = T.err elif rhs.val <= 0: self.err( rhs, "cannot divide or modulo by zero " "or a negative value", + allowed_in_chexo=True, ) - typ = T.err typ = lhs.type elif e.op == "*": @@ -528,8 +581,9 @@ def check_e(self, e, is_index=False): "cannot multiply two non-constant " "indexing/sizing expressions, since " "the result would be non-affine", + allowed_in_chexo=True, ) - typ = T.err + typ = lhs.type else: # + or - if lhs.type == T.index or rhs.type == T.index: typ = T.index @@ -618,6 +672,7 @@ def check_t(self, typ): h, "expected array size expression " "to have type 'size' or type 'index'", + allowed_in_chexo=is_mutable_index(h.type), ) return T.Tensor(hi, typ.is_window, sub_typ) else: diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py index 155b14fde..028ca4d44 100644 --- a/src/exo/libs/externs.py +++ b/src/exo/libs/externs.py @@ -1,6 +1,9 @@ -from exo.core.extern import Extern, _EErr +from ..core.extern import Extern, _EErr import numpy as np +from ..rewrite.constraint_solver import Constraint, DisjointConstraint, Expression +from ..core.prelude import Sym + class _Sin(Extern): def __init__(self): @@ -71,6 +74,25 @@ def transpile(self, args): def compile(self, args, prim_type): return f"_relu_{prim_type}(({prim_type}){args[0]})" + def express_in_constraints( + self, args: tuple[Expression, ...], out_sym: Sym + ) -> DisjointConstraint: + result_expr = Expression.from_sym(out_sym) + return ( + Constraint(args[0], True) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[0].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + .union( + Constraint(args[0].negate(), True) + .lift_to_disjoint_constraint() + .intersect(Constraint(result_expr, False).lift_to_disjoint_constraint()) + ) + ) + relu = _Relu() @@ -118,6 +140,31 @@ def transpile(self, args): def compile(self, args, prim_type): return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})" + def express_in_constraints( + self, args: tuple[Expression, ...], out_sym: Sym + ) -> DisjointConstraint: + result_expr = Expression.from_sym(out_sym) + return ( + Constraint( + args[1].add(args[0].add(Expression.from_constant(1)).negate()), True + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[2].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + .union( + Constraint(args[0].add(args[1].negate()), True) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[3].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + ) + ) + select = _Select() @@ -254,3 +301,72 @@ def compile(self, args, prim_type): sqrt = _Sqrt() + + +class _IntMin(Extern): + def __init__(self): + super().__init__("intmin") + + def typecheck(self, args): + if len(args) != 2: + raise _EErr(f"expected 2 arguments, got {len(args)}") + + for i in range(len(args)): + atyp = args[i].type + if 
not atyp.is_indexable() and not atyp.is_real_scalar(): + raise _EErr( + f"expected argument {i+1} to be a real scalar value or " + f"control flow value, but got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + s = ( + f"{prim_type} _intmin_{prim_type}({prim_type} x,{prim_type} v)" + " {\n" + " if (x < v) return x;\n" + " else return v;\n" + "}\n" + ) + return s + + def interpret(self, args): + x = args[0] + v = args[1] + if x < v: + return x + else: + return v + + def transpile(self, args): + return f"(({args[0]}<{args[1]})?{args[0]}:{args[1]})" + + def compile(self, args, prim_type): + return f"_intmin_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]})" + + def express_in_constraints( + self, args: tuple[Expression, ...], out_sym: Sym + ) -> DisjointConstraint: + result_expr = Expression.from_sym(out_sym) + return ( + Constraint( + args[1].add(args[0].add(Expression.from_constant(1)).negate()), True + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[0].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + .union( + Constraint(args[0].add(args[1].negate()), True) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[1].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + ) + ) + + +intmin = _IntMin diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index 5d61fbc09..bb4d0bffc 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -43,7 +43,7 @@ import exo.API as api from ..frontend.pattern_match import match_pattern from ..core.memory import DRAM -from ..frontend.typecheck import check_call_types +from ..frontend.typecheck import check_call_types, CheckMode from functools import partial @@ -369,9 +369,6 @@ def divide_expr(e, quot): # Scheduling directives -CheckMode = Literal["static", "dynamic", "both"] - - def do_check( static_check: Callable[[], Any], dynamic_check: Callable[[], Any], diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index f0cadc0f9..ceb39c4db 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -71,6 +71,14 @@ def visit(self, node): self.type_map[node.name] = node.type if node.mem: self.mem_map[node.name] = node.mem + elif isinstance(node, LoopIR.Call): + for arg_val, arg in zip(node.args, node.f.args): + if isinstance(arg.type, LoopIR.Tensor) and arg.type.is_window: + self.type_map[arg.name] = arg_val.type + else: + self.type_map[arg.name] = arg.type + for stmt in node.f.body: + self.visit_generic(stmt) else: self.visit_generic(node) @@ -86,6 +94,22 @@ def visit(self, node): self.visit_generic(node) +@dataclass +class ModifiedVariableVisitor(LoopIRVisitor): + type_map: dict[Sym, LoopIR.type] + modified_vars: set[Sym] = field(default_factory=lambda: set()) + + def visit(self, node): + if isinstance(node, (LoopIR.Assign, LoopIR.Reduce)): + node_type = self.type_map[node.name] + if isinstance(node_type, LoopIR.WindowType): + self.modified_vars.add(node_type.src_buf) + else: + self.modified_vars.add(node.name) + else: + self.visit_generic(node) + + class LoopIRModifier: def visit(self, node): return self.visit_generic(node) @@ -201,33 +225,57 @@ def eval_tensor_dimension( def collect_path_constraints( - cursor: Union[Block, Node], cm: ConstraintMaker + cursor: Union[Block, Node], cm: ConstraintMaker, type_map: dict[Sym, LoopIR.type] ) -> DisjointConstraint: - cur = cursor + if isinstance(cursor, Block): + cursor = cursor[0] + assert isinstance(cursor, Node) + 
last_attr, last_index = cursor._path[-1] + cur = cursor.parent() result = TRUE_CONSTRAINT - last_attr = None + var_renaming = {} while cur.depth() != 0: if isinstance(cur, Node): - last_attr = cur._path[-1] if isinstance(cur._node, LoopIR.For): + modified_variable_visitor = ModifiedVariableVisitor(type_map) + for stmt in cur._node.body: + modified_variable_visitor.visit(stmt) + for var_sym in modified_variable_visitor.modified_vars: + var_renaming[var_sym] = cm.copy_var(var_sym) result = result.intersect( cm.make_constraint_from_inequality( - cur._node.iter, cur._node.lo, ">=" + cur._node.iter, cur._node.lo, ">=", var_renaming ).lift_to_disjoint_constraint() ) result = result.intersect( cm.make_constraint_from_inequality( - cur._node.iter, cur._node.hi, "<" + cur._node.iter, cur._node.hi, "<", var_renaming ).lift_to_disjoint_constraint() ) elif isinstance(cur._node, LoopIR.If): - constraint = cm.make_constraint(cur._node.cond) - if isinstance(last_attr, tuple) and last_attr[0] == "orelse": + assert last_index is not None + modified_variable_visitor = ModifiedVariableVisitor(type_map) + for stmt, _ in zip(cur._node[last_attr], range(last_index)): + modified_variable_visitor.visit(stmt) + for var_sym in modified_variable_visitor.modified_vars: + var_renaming[var_sym] = cm.copy_var(var_sym) + constraint = cm.make_constraint(cur._node.cond, var_renaming) + if last_attr == "orelse": result = result.intersect(constraint.invert()) else: result = result.intersect(constraint) + last_attr, last_index = cur._path[-1] cur = cur.parent() + + assert last_index is not None + modified_variable_visitor = ModifiedVariableVisitor(type_map) + for stmt, _ in zip(cur._node.body, range(last_index)): + modified_variable_visitor.visit(stmt) + for var_sym in modified_variable_visitor.modified_vars: + var_renaming[var_sym] = cm.copy_var(var_sym) + for pred in cur._node.preds: + result = result.intersect(cm.make_constraint(pred, var_renaming)) return result @@ -414,6 +462,7 @@ class TestSpec: constraint: DisjointConstraint arg_types: dict[Sym, LoopIR.type] original_scope: Block + var_renaming: dict[Sym, Sym] def forward_to_test(self, cursor: Block) -> Optional[Block]: if cursor in self.original_scope: @@ -479,15 +528,17 @@ def get_type_map(self) -> dict[Sym, LoopIR.type]: proc_type_visitor.visit(root_proc) return proc_type_visitor.type_map - def get_test_spec(self, cm: ConstraintMaker) -> TestSpec: + def get_test_spec( + self, cm: ConstraintMaker, type_map: dict[Sym, LoopIR.type] + ) -> TestSpec: root_proc = self.scope.get_root() proc_type_visitor = TypeVisitor() proc_type_visitor.visit(root_proc) constraint = TRUE_CONSTRAINT - for pred in root_proc.preds: - constraint = constraint.intersect(cm.make_constraint(pred)) - constraint = constraint.intersect(collect_path_constraints(self.scope, cm)) + constraint = constraint.intersect( + collect_path_constraints(self.scope, cm, type_map) + ) args = [ LoopIR.fnarg( name=var, @@ -518,8 +569,16 @@ def get_test_spec(self, cm: ConstraintMaker) -> TestSpec: instr=None, srcinfo=root_proc.srcinfo, ) + modified_variable_visitor = ModifiedVariableVisitor(type_map) + modified_variable_visitor.visit(proc) arg_types = {arg.name: arg.type for arg in args} - return TestSpec(proc, constraint, arg_types, self.scope) + return TestSpec( + proc, + constraint, + arg_types, + self.scope, + {sym: cm.top_var(sym) for sym in modified_variable_visitor.modified_vars}, + ) TEST_CASE_BOUND = 15 @@ -535,20 +594,25 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): 
failure_transformed_scope = fwd(failure_scope) assert isinstance(failure_transformed_scope, Block) cur_scope = TestScope(starting_scope) + cur_type_map = cur_scope.get_type_map() + transformed_type_map = cur_scope.transform(fwd).get_type_map() while cur_scope is not None: transformed = cur_scope.transform(fwd) - cm = ConstraintMaker(cur_scope.get_type_map() | transformed.get_type_map()) + cm = ConstraintMaker(cur_type_map | transformed_type_map) - spec1 = cur_scope.get_test_spec(cm) - spec2 = transformed.get_test_spec(cm) + spec1 = cur_scope.get_test_spec(cm, cur_type_map) + spec2 = transformed.get_test_spec(cm, transformed_type_map) transpiled_test1 = Transpiler( - spec1.proc, CoverageArgs(cm, spec1.forward_to_test(failure_scope)) + spec1.proc, + CoverageArgs(cm, spec1.var_renaming, spec1.forward_to_test(failure_scope)), ) transpiled_test2 = Transpiler( spec2.proc, - CoverageArgs(cm, spec2.forward_to_test(failure_transformed_scope)), + CoverageArgs( + cm, spec2.var_renaming, spec2.forward_to_test(failure_transformed_scope) + ), ) config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() @@ -602,79 +666,6 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): raise SchedulingError("tests failed at broadest scope") -def fuzz_stage_mem( - starting_scope: Block, window_expr: LoopIR.WindowExpr -) -> set[NodePath]: - starting_scope = ( - starting_scope.as_block() - if isinstance(starting_scope, Node) - else starting_scope - ) - failure_scope = starting_scope - stage_scope = failure_scope - cur_scope = TestScope(starting_scope) - - while cur_scope is not None: - cm = ConstraintMaker(cur_scope.get_type_map()) - - spec = cur_scope.get_test_spec(cm) - - forwarded_stage_scope = spec.forward_to_test(stage_scope) - assert forwarded_stage_scope is not None - transpiled_test = Transpiler( - spec.proc, - CoverageArgs( - cm, - spec.forward_to_test(failure_scope), - StageMemArgs(window_expr, forwarded_stage_scope), - ), - ) - - config_fields = transpiled_test.get_configs() - - arg_types = spec.arg_types - constraint = spec.constraint - coverage_skeleton = transpiled_test.get_coverage_skeleton() - assert coverage_skeleton is not None - tests_passed = True - while not coverage_skeleton.get_coverage_progress().is_finished(): - test_case = generate_test_case( - arg_types, - config_fields, - constraint, - coverage_skeleton, - cm, - ) - if test_case is None: - continue - - out = run_test_case(test_case, transpiled_test) - if out == "failed": - tests_passed = False - break - assert out.coverage_result is not None - coverage_skeleton.update_coverage(out.coverage_result) - - if tests_passed: - overlapping_paths = set() - failed = False - for staging_overlap in coverage_skeleton.staging_overlaps: - if staging_overlap.has_disjoint_access and staging_overlap.has_overlap: - failed = True - break - elif staging_overlap.has_overlap: - overlapping_paths.add( - spec.backward_from_test(staging_overlap.access_cursor) - ) - if failed: - cur_scope = cur_scope.broaden() - else: - return overlapping_paths - else: - cur_scope = cur_scope.broaden() - raise SchedulingError("cannot stage due to window overlaps") - - def fuzz_reorder_stmts(s1: Node, s2: Node): starting_scope = s1.as_block().expand(0, 1) _, fwd = s2._move(s1.before()) diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index a8f461163..444531f68 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -1,26 +1,109 @@ from dataclasses import 
dataclass, field -from typing import Literal, Union, Optional +from typing import Callable, Literal, Union, Optional from ..core.configs import Config from ..core.prelude import Sym from ..core.LoopIR import LoopIR, T +from ..core.extern import Extern import numpy as np from scipy.optimize import linprog from hsnf import smith_normal_form import textwrap +@dataclass +class IndexTerm: + buffer_sym: Sym + indices: tuple["Expression", ...] + register_new_index: Callable[[tuple[int, ...]], Sym] + + def substitute(self, assignments: dict[Sym, int]) -> Union["IndexTerm", Sym]: + new_indices = tuple(index.substitute(assignments) for index in self.indices) + int_indices = [] + trivial = True + for new_index in new_indices: + trivial_val = new_index.get_trivial_result() + if trivial_val is None: + trivial = False + break + else: + int_indices.append(trivial_val) + if trivial: + return self.register_new_index(tuple(int_indices)) + return IndexTerm(self.buffer_sym, new_indices, self.register_new_index) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return frozenset().union(*(index.collect_syms() for index in self.indices)) + + def collect_syms(self) -> frozenset[Sym]: + return self.collect_nonlinear_syms() + + def pretty_print(self) -> str: + index_str = ",".join(index.pretty_print() for index in self.indices) + return f"{str(self.buffer_sym)}[{index_str}]" + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "IndexTerm": + return IndexTerm( + self.buffer_sym, + tuple(index.rename_syms(lookup) for index in self.indices), + self.register_new_index, + ) + + +@dataclass +class ExternTerm: + extern: Extern + args: tuple["Expression", ...] + + def substitute(self, assignments: dict[Sym, int]) -> Union["ExternTerm", int]: + new_args = tuple(arg.substitute(assignments) for arg in self.args) + int_args = [] + trivial = True + for new_arg in new_args: + trivial_val = new_arg.get_trivial_result() + if trivial_val is None: + trivial = False + break + else: + int_args.append(trivial_val) + if trivial: + return self.extern.interpret(tuple(int_args)) + return ExternTerm(self.extern, new_args) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return frozenset().union(*(arg.collect_syms() for arg in self.args)) + + def collect_syms(self) -> frozenset[Sym]: + return self.collect_nonlinear_syms() + + def pretty_print(self) -> str: + arg_str = ",".join(arg.pretty_print() for arg in self.args) + return f"{str(self.extern)}({arg_str})" + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ExternTerm": + return ExternTerm( + self.extern, + tuple(arg.rename_syms(lookup) for arg in self.args), + ) + + +FunctionTerm = Union[IndexTerm, ExternTerm] + + @dataclass class ConstraintTerm: coefficient: int syms: tuple[Sym, ...] + functions: tuple[FunctionTerm, ...] 
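+    # Nonlinear factors (buffer index reads and extern calls) multiplied into this term; empty when the term is purely linear.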
def negate(self) -> "ConstraintTerm": - return ConstraintTerm(-self.coefficient, self.syms) + return ConstraintTerm(-self.coefficient, self.syms, self.functions) def multiply(self, other) -> "ConstraintTerm": return ConstraintTerm( - self.coefficient * other.coefficient, self.syms + other.syms + self.coefficient * other.coefficient, + self.syms + other.syms, + self.functions + other.functions, ) def substitute(self, assignments: dict[Sym, int]) -> "ConstraintTerm": @@ -31,7 +114,20 @@ def substitute(self, assignments: dict[Sym, int]) -> "ConstraintTerm": new_coefficient *= assignments[sym] else: new_syms.append(sym) - return ConstraintTerm(new_coefficient, tuple(new_syms)) + new_functions = [] + for function in self.functions: + sub = function.substitute(assignments) + if isinstance(sub, int): + new_coefficient *= sub + elif isinstance(sub, Sym): + new_syms.append(sub) + else: + new_functions.append(sub) + return ConstraintTerm( + new_coefficient, + tuple(new_syms), + tuple(new_functions), + ) def collect_nonlinear_syms(self) -> frozenset[Sym]: occurrences = set() @@ -43,15 +139,19 @@ def collect_nonlinear_syms(self) -> frozenset[Sym]: occurrences.add(sym) return frozenset(result) - def pretty_print(self) -> str: - return ( - f"{' * '.join([str(self.coefficient)] + [str(sym) for sym in self.syms])}" + def collect_syms(self) -> frozenset[Sym]: + return frozenset(self.syms).union( + *(function.collect_syms() for function in self.functions) ) + def pretty_print(self) -> str: + return f"{' * '.join([str(self.coefficient)] + [str(sym) for sym in self.syms] + [function.pretty_print() for function in self.functions])}" + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ConstraintTerm": return ConstraintTerm( self.coefficient, tuple(lookup[sym] if sym in lookup else sym for sym in self.syms), + tuple(function.rename_syms(lookup) for function in self.functions), ) @@ -61,47 +161,77 @@ class LinearConstraint: offset: int has_slack: bool + def get_trivial_result(self) -> Optional[bool]: + if len(self.coefficients) > 0: + return None + return (self.offset >= 0 and self.has_slack) or self.offset == 0 + @dataclass class Expression: - terms: tuple[ConstraintTerm, ...] 
+ terms: Optional[tuple[ConstraintTerm, ...]] @staticmethod def from_constant(const: int) -> "Expression": - return Expression((ConstraintTerm(const, ()),)) + return Expression((ConstraintTerm(const, (), ()),)) @staticmethod def from_sym(sym: Sym) -> "Expression": - return Expression((ConstraintTerm(1, (sym,)),)) + return Expression((ConstraintTerm(1, (sym,), ()),)) + + @staticmethod + def unsolvable() -> "Expression": + return Expression(None) + + @staticmethod + def from_function(function_term: FunctionTerm) -> "Expression": + return Expression((ConstraintTerm(1, (), (function_term,)),)) def negate(self) -> "Expression": - return Expression(tuple(term.negate() for term in self.terms)) + return Expression( + None if self.terms is None else tuple(term.negate() for term in self.terms) + ) def add(self, other: "Expression") -> "Expression": - return Expression((*self.terms, *other.terms)) + return Expression( + None + if self.terms is None or other.terms is None + else (*self.terms, *other.terms) + ) def multiply(self, other: "Expression") -> "Expression": return Expression( - tuple( + None + if self.terms is None or other.terms is None + else tuple( term1.multiply(term2) for term1 in self.terms for term2 in other.terms ) ) def substitute(self, assignments: dict[Sym, int]) -> "Expression": + if self.terms is None: + return self coefficients: dict[tuple[Sym, ...], int] = {} + other_terms: list[ConstraintTerm] = [] for term in self.terms: sub_term = term.substitute(assignments) - if sub_term.syms not in coefficients: - coefficients[sub_term.syms] = 0 - coefficients[sub_term.syms] += sub_term.coefficient + if len(sub_term.functions) != 0: + other_terms.append(sub_term) + else: + if sub_term.syms not in coefficients: + coefficients[sub_term.syms] = 0 + coefficients[sub_term.syms] += sub_term.coefficient return Expression( tuple( - ConstraintTerm(coefficient, syms) + ConstraintTerm(coefficient, syms, ()) for syms, coefficient in coefficients.items() ) + + tuple(other_terms) ) def get_trivial_result(self) -> Optional[int]: + if self.terms is None: + return None if len(self.terms) == 0: return 0 elif len(self.terms) == 1 and len(self.terms[0].syms) == 0: @@ -109,17 +239,25 @@ def get_trivial_result(self) -> Optional[int]: return None def collect_syms(self) -> frozenset[Sym]: - return frozenset(sym for term in self.terms for sym in term.syms) + if self.terms is None: + return frozenset() + return frozenset().union(*(term.collect_syms() for term in self.terms)) def collect_nonlinear_syms(self) -> frozenset[Sym]: + if self.terms is None: + return frozenset() return frozenset().union( *[term.collect_nonlinear_syms() for term in self.terms] ) def pretty_print(self): + if self.terms is None: + return "unsolvable" return " + ".join([term.pretty_print() for term in self.terms]) def rename_syms(self, lookup: dict[Sym, Sym]) -> "Expression": + if self.terms is None: + return self return Expression(tuple(term.rename_syms(lookup) for term in self.terms)) @@ -130,10 +268,14 @@ class Constraint: def linearize(self, assignments: dict[Sym, int]) -> Optional[LinearConstraint]: new_lhs = self.lhs.substitute(assignments) + if new_lhs.terms is None: + return None offset = 0 coefficients = {} for term in new_lhs.terms: - if len(term.syms) == 0: + if len(term.functions) != 0: + return None + elif len(term.syms) == 0: offset += term.coefficient elif len(term.syms) == 1: coefficients[term.syms[0]] = term.coefficient @@ -190,6 +332,9 @@ def get_trivial_result(self) -> Optional[bool]: return (lhs_result >= 0 and 
self.has_slack) or lhs_result == 0 return None + def is_unsolvable(self) -> bool: + return self.lhs.terms is None + def rename_syms(self, lookup: dict[Sym, Sym]) -> "Constraint": return Constraint(self.lhs.rename_syms(lookup), self.has_slack) @@ -312,32 +457,40 @@ def rename_syms(self, lookup: dict[Sym, Sym]) -> "DisjointConstraint": class Solution: ctxt: dict[tuple[Config, str], int] var_assignments: dict[Sym, int] + buffer_assignments: dict[tuple[Sym, tuple[int, ...]], int] substitutions: dict[Sym, int] - def merge_solutions(self, other: "Solution", other_renaming: dict[Sym, Sym]): - return Solution( - self.ctxt, - self.var_assignments, - { - other_renaming[key] if key in other_renaming else key: value - for key, value in other.substitutions.items() - } - | self.substitutions, - ) - class ConstraintMaker: def __init__(self, type_map: dict[Sym, LoopIR.type]): self.var_subs: dict[Sym, Expression] = {} self.ctxt: dict[tuple[Config, str], Expression] = {} - self.extra_constraints: list[Constraint] = [] - self.hidden_vars: set[Sym] = set() + self.extra_constraints: dict[Sym, DisjointConstraint] = {} self.stride_dummies: dict[tuple[Sym, int], Sym] = {} + self.buffer_syms: dict[Sym, tuple[Sym, tuple[int, ...]]] = {} + self.unbound_buffers: set[Sym] = set() + self.type_map = type_map for sym, sym_type in type_map.items(): var_sub_result = self.make_var_sub(sym.name(), sym_type) if var_sub_result is not None: self.var_subs[sym] = var_sub_result + def copy_var(self, sym: Sym) -> Sym: + new_sym = Sym(f"{sym.name()}_copy") + var_sub_result = self.make_var_sub(new_sym.name(), self.type_map[sym]) + if var_sub_result is not None: + self.var_subs[new_sym] = var_sub_result + return new_sym + + def top_var(self, sym: Sym) -> Sym: + new_sym = Sym(f"{sym.name()}_top") + sym_type = self.type_map[sym] + if sym_type.is_tensor_or_window(): + self.unbound_buffers.add(new_sym) + else: + self.var_subs[new_sym] = Expression.unsolvable() + return new_sym + def get_var_sub(self, var_sym: Sym) -> Expression: return self.var_subs[var_sym] @@ -354,33 +507,78 @@ def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression] elif isinstance(var_type, T.Bool): # constrained to [0, 1] sym = Sym(name) - self.extra_constraints.append( - Constraint( - Expression.from_sym(sym).negate().add(Expression.from_constant(1)), - True, - ) - ) + self.extra_constraints[sym] = Constraint( + Expression.from_sym(sym).negate().add(Expression.from_constant(1)), + True, + ).lift_to_disjoint_constraint() return Expression.from_sym(sym) else: return None - def make_expression(self, expr: Union[LoopIR.expr, Sym]) -> Expression: + def register_buffer_index(self, indices: tuple[int, ...], buffer_sym: Sym) -> Sym: + sym = Sym("buf") + self.buffer_syms[sym] = (buffer_sym, indices) + return sym + + def make_expression( + self, expr: Union[LoopIR.expr, Sym], var_renaming: dict[Sym, Sym] + ) -> Expression: # expect that expr is int type if isinstance(expr, Sym): + if expr in var_renaming: + return self.var_subs[var_renaming[expr]] return self.var_subs[expr] elif isinstance(expr, LoopIR.Read): - assert ( - len(expr.idx) == 0 - ), "indexing not supported in assertions (yet, todo)" - return self.var_subs[expr.name] + if len(expr.idx) == 0: + return self.var_subs[expr.name] + else: + buf_type = self.type_map[expr.name] + if isinstance(buf_type, LoopIR.Tensor): + buf_name = expr.name + index_exprs = tuple( + self.make_expression(idx, var_renaming) for idx in expr.idx + ) + elif isinstance(buf_type, LoopIR.WindowType): + buf_name = 
buf_type.src_buf + index_list: list[Expression] = [] + expr_idx_iter = iter(expr.idx) + for idx in buf_type.idx: + if isinstance(idx, LoopIR.Point): + index_list.append( + self.make_expression(idx.pt, var_renaming) + ) + elif isinstance(idx, LoopIR.Interval): + index_list.append( + self.make_expression(idx.lo, var_renaming).add( + self.make_expression( + next(expr_idx_iter), var_renaming + ) + ) + ) + else: + assert False, "unexpected window access" + index_exprs = tuple(index_list) + else: + assert False, "unexpected buffer type" + if buf_name in var_renaming: + buf_name = var_renaming[buf_name] + if buf_name in self.unbound_buffers: + return Expression.unsolvable() + return Expression.from_function( + IndexTerm( + buf_name, + index_exprs, + lambda indices: self.register_buffer_index(indices, buf_name), + ) + ) elif isinstance(expr, LoopIR.Const): return Expression.from_constant(expr.val) elif isinstance(expr, LoopIR.USub): - return self.make_expression(expr.arg).negate() + return self.make_expression(expr.arg, var_renaming).negate() elif isinstance(expr, LoopIR.BinOp): # TODO: support mod and div using extra variables - lhs = self.make_expression(expr.lhs) - rhs = self.make_expression(expr.rhs) + lhs = self.make_expression(expr.lhs, var_renaming) + rhs = self.make_expression(expr.rhs, var_renaming) if expr.op == "+": return lhs.add(rhs) elif expr.op == "-": @@ -389,27 +587,42 @@ def make_expression(self, expr: Union[LoopIR.expr, Sym]) -> Expression: return lhs.multiply(rhs) elif expr.op in ["/", "%"]: div, rem = Sym("div"), Sym("rem") - self.hidden_vars.update((div, rem)) - self.extra_constraints.append( + visible_sym = rem if expr.op == "%" else div + self.extra_constraints[visible_sym] = ( Constraint( lhs.negate() .add(Expression.from_sym(rem)) .add(rhs.multiply(Expression.from_sym(div))), False, ) - ) - self.extra_constraints.append( - Constraint( - Expression.from_sym(rem) - .add(Expression.from_constant(1)) - .negate() - .add(rhs), - True, + .lift_to_disjoint_constraint() + .intersect( + Constraint( + Expression.from_sym(rem) + .add(Expression.from_constant(1)) + .negate() + .add(rhs), + True, + ).lift_to_disjoint_constraint() ) ) - return Expression.from_sym(rem if expr.op == "%" else div) + return Expression.from_sym(visible_sym) else: assert False, f"unsupported op in assertion: {expr.op}" + elif isinstance(expr, LoopIR.Extern): + extern_args = tuple( + self.make_expression(arg, var_renaming) for arg in expr.args + ) + extern: Extern = expr.f + extern_result_sym = Sym("ext") + try: + extern_constraint = extern.express_in_constraints( + extern_args, extern_result_sym + ) + self.extra_constraints[extern_result_sym] = extern_constraint + return Expression.from_sym(extern_result_sym) + except NotImplementedError: + return Expression.from_function(ExternTerm(extern, extern_args)) elif isinstance(expr, LoopIR.StrideExpr): if (expr.name, expr.dim) not in self.stride_dummies: new_sym = Sym("stride") @@ -431,31 +644,30 @@ def make_expression(self, expr: Union[LoopIR.expr, Sym]) -> Expression: assert False, f"unsupported expr" def make_constraint( - self, - expr: LoopIR.expr, + self, expr: LoopIR.expr, var_renaming: dict[Sym, Sym] ) -> DisjointConstraint: # expect that expr is bool type if isinstance(expr, LoopIR.BinOp): if expr.op == "and": lhs_constraints, rhs_constraints = self.make_constraint( - expr.lhs - ), self.make_constraint(expr.rhs) + expr.lhs, var_renaming + ), self.make_constraint(expr.rhs, var_renaming) return lhs_constraints.intersect(rhs_constraints) elif expr.op == "or": 
lhs_constraints, rhs_constraints = self.make_constraint( - expr.lhs - ), self.make_constraint(expr.rhs) + expr.lhs, var_renaming + ), self.make_constraint(expr.rhs, var_renaming) return lhs_constraints.union(rhs_constraints) elif expr.op == "==" and isinstance(expr.lhs.type, LoopIR.Bool): lhs_constraints, rhs_constraints = self.make_constraint( - expr.lhs - ), self.make_constraint(expr.rhs) + expr.lhs, var_renaming + ), self.make_constraint(expr.rhs, var_renaming) return ( lhs_constraints.invert().intersect(rhs_constraints.invert()) ).union(lhs_constraints.intersect(rhs_constraints)) else: return self.make_constraint_from_inequality( - expr.lhs, expr.rhs, expr.op + expr.lhs, expr.rhs, expr.op, var_renaming ).lift_to_disjoint_constraint() elif isinstance(expr, LoopIR.Read): assert len(expr.idx) == 0, "cannot index into boolean" @@ -465,7 +677,6 @@ def make_constraint( ).lift_to_disjoint_constraint() elif isinstance(expr, LoopIR.Const): return TRUE_CONSTRAINT if expr.val else FALSE_CONSTRAINT - elif isinstance(expr, LoopIR.ReadConfig): if (expr.config, expr.field) not in self.ctxt: field_type = expr.config.lookup_type(expr.field) @@ -485,10 +696,14 @@ def make_constraint( assert False, "only boolean expected" def make_constraint_from_inequality( - self, lhs: Union[LoopIR.expr, Sym], rhs: Union[LoopIR.expr, Sym], op: str + self, + lhs: Union[LoopIR.expr, Sym], + rhs: Union[LoopIR.expr, Sym], + op: str, + var_renaming: dict[Sym, Sym], ) -> Constraint: - lhs_expr = self.make_expression(lhs) - rhs_expr = self.make_expression(rhs) + lhs_expr = self.make_expression(lhs, var_renaming) + rhs_expr = self.make_expression(rhs, var_renaming) if op == "<": return Constraint( rhs_expr.add(lhs_expr.negate()).add(Expression.from_constant(-1)), True @@ -512,33 +727,48 @@ def _make_solution_from_assignments(self, assignments: dict[Sym, int]) -> Soluti result = sub.substitute(assignments).get_trivial_result() if result is not None: var_assignments[sym] = result + buffer_assignments = {} + for sym, assignment in assignments.items(): + if sym in self.buffer_syms: + buffer_assignments[self.buffer_syms[sym]] = assignment ctxt = {} for (config, field), sub in self.ctxt.items(): result = sub.substitute(assignments).get_trivial_result() if result is not None: ctxt[(config, field)] = result - return Solution(ctxt, var_assignments, assignments) + return Solution(ctxt, var_assignments, buffer_assignments, assignments) def _solve_for_assignments( self, all_constraints: tuple[Constraint, ...], bound: int ) -> Union[Literal["failed", "infeasible"], dict[Sym, int]]: - sym_universe = set() - for constraint in all_constraints: - sym_universe |= constraint.collect_syms() assignments = {} - while len(assignments) < len(sym_universe): + self.buffer_syms = {} + while True: linear_constraints: list[LinearConstraint] = [] linear_constraint_syms: set[Sym] = set() nonlinear_syms: set[Sym] = set() + nontrivial_constraint_exists = False for constraint in all_constraints: + if constraint.is_unsolvable(): + return "infeasible" linear_result = constraint.linearize(assignments) if linear_result is not None: - linear_constraints.append(linear_result) - linear_constraint_syms |= { - sym for sym in linear_result.coefficients.keys() - } - nonlinear_syms |= constraint.collect_nonlinear_syms() - nonlinear_syms -= assignments.keys() + trivial_result = linear_result.get_trivial_result() + if trivial_result == False: + return "infeasible" if len(assignments) == 0 else "failed" + elif trivial_result is None: + nontrivial_constraint_exists = True + 
linear_constraints.append(linear_result) + linear_constraint_syms |= { + sym for sym in linear_result.coefficients.keys() + } + else: + nontrivial_constraint_exists = True + nonlinear_syms |= constraint.substitute( + assignments + ).collect_nonlinear_syms() + if not nontrivial_constraint_exists: + break priority_syms = nonlinear_syms & linear_constraint_syms if len(priority_syms) == 0 and len(nonlinear_syms) != 0: chosen_sym = np.random.choice( @@ -663,31 +893,18 @@ def _solve_for_assignments( if len(nonlinear_syms) == 0: for sym in linear_constraint_syms: assignments[sym] = int(solution[sym_ordering[sym]]) - for sym in sym_universe - assignments.keys(): - assignments[sym] = np.random.randint(0, bound) else: chosen_sym = None if len(priority_syms) != 0: chosen_sym = np.random.choice( sorted(list(priority_syms), key=lambda sym: sym._id) ) - elif len(linear_constraint_syms) != 0: + else: + assert len(linear_constraint_syms) != 0 chosen_sym = np.random.choice( sorted(list(linear_constraint_syms), key=lambda sym: sym._id) ) - if chosen_sym is None: - free_syms = ( - sym_universe - - linear_constraint_syms - - assignments.keys() - - nonlinear_syms - ) - chosen_sym = np.random.choice( - sorted(list(free_syms), key=lambda sym: sym._id) - ) - assignments[chosen_sym] = np.random.randint(0, bound) - else: - assignments[chosen_sym] = int(solution[sym_ordering[chosen_sym]]) + assignments[chosen_sym] = int(solution[sym_ordering[chosen_sym]]) return assignments def solve_constraint( @@ -706,13 +923,38 @@ def solve_constraint( partial_solution.substitutions ) - clauses = list(disjoint_constraint.clauses) + clauses = list( + clause + for clause in disjoint_constraint.clauses + if all(not constraint.is_unsolvable() for constraint in clause.constraints) + ) for _ in range(search_limit): if len(clauses) == 0: return None chosen_clause = np.random.choice(clauses) assert isinstance(chosen_clause, ConstraintClause) - all_constraints = chosen_clause.constraints + tuple(self.extra_constraints) + chosen_clause_syms = chosen_clause.collect_syms() + chosen_extra_clauses: list[Constraint] = [] + failed_to_choose = False + for sym, extra_constraint in self.extra_constraints.items(): + if sym in chosen_clause_syms: + extra_constraint_clauses = list( + clause + for clause in extra_constraint.clauses + if all( + not constraint.is_unsolvable() + for constraint in clause.constraints + ) + ) + if len(extra_constraint_clauses) == 0: + failed_to_choose = True + break + chosen_extra_clause = np.random.choice(extra_constraint_clauses) + assert isinstance(chosen_extra_clause, ConstraintClause) + chosen_extra_clauses.extend(chosen_extra_clause.constraints) + if failed_to_choose: + continue + all_constraints = chosen_clause.constraints + tuple(chosen_extra_clauses) assignment_result = self._solve_for_assignments(all_constraints, bound) if assignment_result == "failed": continue @@ -729,7 +971,7 @@ def rename_sym_set( self, syms: frozenset[Sym], free_vars: frozenset[Sym] ) -> tuple[dict[Sym, Sym], dict[Sym, Sym]]: var_renaming = {} - sym_renaming = {sym: Sym(sym.name()) for sym in self.hidden_vars & syms} + sym_renaming = {} for var in free_vars: var_sub = self.var_subs[var] var_sub_syms = var_sub.collect_syms() @@ -738,13 +980,18 @@ def rename_sym_set( renamed_var = Sym(var.name()) var_renaming[var] = renamed_var self.var_subs[renamed_var] = var_sub.rename_syms(sym_renaming) - self.extra_constraints.extend( - tuple( - extra_constraint.rename_syms(sym_renaming) - for extra_constraint in self.extra_constraints - if 
len(extra_constraint.collect_syms() & sym_renaming.keys()) != 0 - ) - ) + new_extra_constraints = {} + for sym, extra_constraint in self.extra_constraints.items(): + if ( + len(extra_constraint.collect_syms() & sym_renaming.keys()) != 0 + and sym in syms + ): + new_sym = Sym(sym.name()) + new_extra_constraints[new_sym] = extra_constraint.rename_syms( + sym_renaming + ) + sym_renaming[sym] = new_sym + self.extra_constraints |= new_extra_constraints return ( sym_renaming, var_renaming, From d8fd8fd87e2d9371d1849941f4cd0c0d5ee16a21 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Mon, 19 May 2025 12:18:24 -0400 Subject: [PATCH 18/24] thesis done --- src/exo/API.py | 14 +- src/exo/API_scheduling.py | 191 +- src/exo/backend/LoopIR_transpiler.py | 5 +- src/exo/backend/coverage.py | 13 +- src/exo/frontend/typecheck.py | 1 + src/exo/libs/externs.py | 2 +- src/exo/rewrite/LoopIR_scheduling.py | 2653 +++++++++++++++++++------- src/exo/rewrite/chexo.py | 124 +- src/exo/rewrite/constraint_solver.py | 31 +- src/exo/stdlib/scheduling.py | 3 + 10 files changed, 2275 insertions(+), 762 deletions(-) diff --git a/src/exo/API.py b/src/exo/API.py index 84f195993..ac5b19503 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -50,7 +50,7 @@ def proc(f, _instr=None, _check_mode: Optional[CheckMode] = None) -> "Procedure" return Procedure(parser.result(), _check_mode=_check_mode) -def instr(c_instr, c_global=""): +def instr(c_instr, c_global="", check_mode=None): if not isinstance(c_instr, str): raise TypeError("@instr decorator must be @instr()") @@ -58,7 +58,7 @@ def inner(f): if not isinstance(f, types.FunctionType): raise TypeError("@instr decorator must be applied to a function") - return proc(f, _instr=(c_instr, c_global)) + return proc(f, _instr=(c_instr, c_global), _check_mode=check_mode) return inner @@ -179,7 +179,15 @@ def __init__( _mod_config = _mod_config or frozenset() - self._check_mode = DEFAULT_CHECK_MODE if _check_mode is None else _check_mode + self._check_mode = ( + ( + _provenance_eq_Procedure._check_mode + if _provenance_eq_Procedure is not None + else DEFAULT_CHECK_MODE + ) + if _check_mode is None + else _check_mode + ) if isinstance(proc, LoopIR.UAST.proc): proc = TypeChecker(proc, self._check_mode).get_loopir() if self._check_mode != "dynamic": diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index fcf2ca7b6..5db2d8845 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -720,6 +720,34 @@ def __call__(self, expr_str, all_args): return expr +class NewStmtA(ArgumentProcessor): + def __init__(self, cursor_arg, before=True): + self.cursor_arg = cursor_arg + self.before = before + + def _get_ctxt_stmt(self, all_args): + cursor = all_args[self.cursor_arg] + while isinstance(cursor, PC.ExprCursor): + cursor = cursor.parent() + + # if we don't have a gap cursor, convert to a gap cursor + if not isinstance(cursor, PC.GapCursor): + cursor = cursor.before() if self.before else cursor.after() + + # TODO: improve parse_fragment to just take gaps + return cursor.anchor()._impl._node + + def __call__(self, stmt_str, all_args): + if not isinstance(stmt_str, str): + self.err("expected a string") + + proc = all_args["proc"] + ctxt_stmt = self._get_ctxt_stmt(all_args) + + expr = parse_fragment(proc._loopir_proc, stmt_str, ctxt_stmt) + return expr + + # This is implemented as a workaround because the # current PAST parser and PAST IR don't support windowing # expressions. 
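The call-site changes in the hunks below all follow one pattern: each scheduling wrapper now forwards proc._check_mode into its Do* primitive, and the primitive picks between the existing static verifier and the chexo fuzzer. A minimal sketch of that dispatch, assuming only the do_check(static_check, dynamic_check, check_mode) call shape used by the Do* primitives later in this patch; the helper's real body is not shown in this excerpt, and any CheckMode value other than "dynamic" is an assumption:

    # Illustrative sketch only -- not part of the patch.
    from typing import Callable, TypeVar

    R = TypeVar("R")

    def do_check(
        static_check: Callable[[], R],
        dynamic_check: Callable[[], R],
        check_mode,  # a CheckMode; "dynamic" selects the fuzzing path
    ) -> R:
        # "dynamic" routes the rewrite through the chexo fuzzer
        # (dynamic_check builds ir/fwd and then calls fuzz(...));
        # every other mode keeps the original static safety checks.
        if check_mode == "dynamic":
            return dynamic_check()
        return static_check()

Either branch is expected to build and return the same (ir, fwd) result, which is why primitives such as DoJoinLoops can simply return do_check(static_check, dynamic_check, check_mode) while others call it purely for its checking side effect.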
@@ -897,7 +925,7 @@ def reorder_stmts(proc, block_cursor): def parallelize_loop(proc, loop_cursor): loop = loop_cursor._impl - ir, fwd = scheduling.DoParallelizeLoop(loop) + ir, fwd = scheduling.DoParallelizeLoop(loop, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -977,7 +1005,7 @@ def rewrite_expr(proc, expr_cursor, new_expr): -> `s[ expr_cursor -> new_expr]` """ - ir, fwd = scheduling.DoRewriteExpr(expr_cursor._impl, new_expr) + ir, fwd = scheduling.DoRewriteExpr(expr_cursor._impl, new_expr, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1003,13 +1031,16 @@ def bind_expr(proc, expr_cursors, new_name): `a = b + 4.0` """ exprs = [ec._impl for ec in expr_cursors] - if any(not e._node.type.is_numeric() for e in exprs): + if ( + any(not e._node.type.is_numeric() for e in exprs) + and not proc._check_mode == "dynamic" + ): raise TypeError( "only numeric (not index or size) expressions " "can be bound by bind_expr()" ) - ir, fwd = scheduling.DoBindExpr(new_name, exprs) + ir, fwd = scheduling.DoBindExpr(new_name, exprs, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1202,10 +1233,14 @@ def bind_config(proc, var_cursor, config, field): e = var_cursor._impl._node cfg_f_type = config.lookup_type(field) - if not isinstance(e, LoopIR.Read): + if not isinstance(e, LoopIR.Read) and proc._check_mode != "dynamic": raise TypeError("expected a cursor to a single variable Read") - if not (e.type.is_real_scalar() and len(e.idx) == 0) and not e.type.is_bool(): + if ( + not (e.type.is_real_scalar() and len(e.idx) == 0) + and not e.type.is_bool() + and proc._check_mode != "dynamic" + ): raise TypeError( f"cannot bind non-real-scalar non-boolean value {e} to configuration states, since index and size expressions may depend on loop iteration" ) @@ -1234,10 +1269,41 @@ def delete_config(proc, stmt_cursor): rewrite: `s1 ; config.field = _ ; s3 -> s1 ; s3` """ - ir, fwd, cfg = scheduling.DoDeleteConfig(proc._root(), stmt_cursor._impl) + ir, fwd, cfg = scheduling.DoDeleteConfig( + proc._root(), stmt_cursor._impl, proc._check_mode + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd, _mod_config=cfg) +@sched_op([StmtCursorA]) +def delete_stmt(proc, stmt_cursor): + """ + delete a statement + + args: + stmt_cursor - cursor or pattern pointing at the statement to + be deleted + + rewrite: + `s1 ; s2 ; s3 -> s1 ; s3` + """ + ir, fwd = scheduling.DoDeleteStmt(proc._root(), stmt_cursor._impl, proc._check_mode) + return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) + + +@sched_op([GapCursorA, NewExprA("gap_cursor"), NewExprA("gap_cursor"), BoolA]) +def insert_mutate(proc, gap_cursor, buf_read, rhs, is_reduce): + if not (isinstance(buf_read, LoopIR.Read) and len(buf_read.idx) == 0): + raise SchedulingError() + new_stmt = (LoopIR.Reduce if is_reduce else LoopIR.Assign)( + buf_read.name, buf_read.type, rhs, buf_read.srcinfo + ) + ir, fwd = scheduling.DoInsertStmt( + proc._root(), gap_cursor._impl, new_stmt, proc._check_mode + ) + return Procedure(ir, __provenance_eq_Procedure=proc, _forward=fwd) + + @sched_op([GapCursorA, ConfigA, ConfigFieldA, NewExprA("gap_cursor")]) def write_config(proc, gap_cursor, config, field, rhs): """ @@ -1271,7 +1337,9 @@ def write_config(proc, gap_cursor, config, field, rhs): ) stmt = stmtc._impl - ir, fwd, cfg = scheduling.DoConfigWrite(stmt, config, field, rhs, before=before) + ir, fwd, cfg = scheduling.DoConfigWrite( + stmt, config, field, rhs, 
proc._check_mode, before=before + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd, _mod_config=cfg) @@ -1320,11 +1388,13 @@ def resize_dim(proc, buf_cursor, dim_idx, size, offset, fold: bool = False): assert isinstance(size, LoopIR.Const) and size.val > 0 size = size.val buf_s = buf_cursor._impl - ir, fwd = scheduling.DoFoldBuffer(buf_s, dim_idx, size) + ir, fwd = scheduling.DoFoldBuffer(buf_s, dim_idx, size, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) else: # Normal resize operation - ir, fwd = scheduling.DoResizeDim(stmt_c, dim_idx, size, offset) + ir, fwd = scheduling.DoResizeDim( + stmt_c, dim_idx, size, offset, proc._check_mode + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1354,7 +1424,7 @@ def expand_dim(proc, buf_cursor, alloc_dim, indexing_expr): provided indexing expression is checked to make sure it is in-bounds """ stmt_c = buf_cursor._impl - ir, fwd = scheduling.DoExpandDim(stmt_c, alloc_dim, indexing_expr) + ir, fwd = scheduling.DoExpandDim(stmt_c, alloc_dim, indexing_expr, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1410,7 +1480,7 @@ def divide_dim(proc, alloc_cursor, dim_idx, quotient): if not (0 <= dim_idx < len(stmt._node.type.shape())): raise ValueError(f"Cannot divide out-of-bounds dimension index {dim_idx}") - ir, fwd = scheduling.DoDivideDim(stmt, dim_idx, quotient) + ir, fwd = scheduling.DoDivideDim(stmt, dim_idx, quotient, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1584,7 +1654,7 @@ def reuse_buffer(proc, buf_cursor, replace_cursor): """ buf_s = buf_cursor._impl rep_s = replace_cursor._impl - ir, fwd = scheduling.DoReuseBuffer(buf_s, rep_s) + ir, fwd = scheduling.DoReuseBuffer(buf_s, rep_s, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1681,7 +1751,12 @@ def divide_with_recompute(proc, loop_cursor, outer_hi, outer_stride, new_iters): ` s[ i -> outer_stride * io + ii ]` """ ir, fwd = scheduling.DoDivideWithRecompute( - loop_cursor._impl, outer_hi, outer_stride, new_iters[0], new_iters[1] + loop_cursor._impl, + outer_hi, + outer_stride, + new_iters[0], + new_iters[1], + proc._check_mode, ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1738,12 +1813,74 @@ def divide_loop(proc, loop_cursor, div_const, new_iters, tail="guard", perfect=F quot=div_const, outer_iter=new_iters[0], inner_iter=new_iters[1], + check_mode=proc._check_mode, tail=tail, perfect=perfect, ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) +@sched_op( + [ + ForCursorA, + ConfigA, + ConfigFieldA("quot_config"), + ListA(NameA, length=2), + ] +) +def divide_loop_min(proc, loop_cursor, quot_config, quot_field, new_iters): + """ + Divide a loop into an outer and inner loop, where the inner loop + iterates over the range 0 to `div_const`. + + Old Name: In Halide and TVM, this was called "split" + + args: + loop_cursor - cursor pointing to the loop to split ; + can also be specified using the special shorthands + pattern: + or: # + div_const - integer > 1 specifying what to "divide by" + new_iters - list or tuple of two strings specifying the new + outer and inner iteration variable names + tail (opt) - specifies the strategy for handling the "remainder" + of the loop division (called the tail of the loop). + value can be "cut", "guard", or "cut_and_guard". 
+ Default value: "guard" + perfect (opt) - Boolean (default False) that can be set to true + to assert that you know the remainder will always + be zero (i.e. there is no tail). You will get an + error if the compiler cannot verify this fact itself. + + rewrite: + divide(..., div_const=q, new_iters=['hi','lo'], tail='cut') + `for i in seq(0,e):` + ` s` + -> + `for hi in seq(0,e / q):` + ` for lo in seq(0, q):` + ` s[ i -> q*hi + lo ]` + `for lo in seq(0,e - q * (e / q)):` + ` s[ i -> q * (e / q) + lo ] + """ + + stmt = loop_cursor._impl + + ir, fwd = scheduling.DoDivideLoopMin( + stmt, + quot=LoopIR.ReadConfig( + quot_config, + quot_field, + quot_config.lookup_type(quot_field), + loop_cursor._impl._node.srcinfo, + ), + outer_iter=new_iters[0], + inner_iter=new_iters[1], + check_mode=proc._check_mode, + ) + return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) + + @sched_op([NestedForCursorA, NameA]) def mult_loops(proc, nested_loops, new_iter_name): """ @@ -1787,7 +1924,9 @@ def join_loops(proc, loop1_cursor, loop2_cursor): `for i in seq(lo, hi):` ` s` """ - ir, fwd = scheduling.DoJoinLoops(loop1_cursor._impl, loop2_cursor._impl) + ir, fwd = scheduling.DoJoinLoops( + loop1_cursor._impl, loop2_cursor._impl, proc._check_mode + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1815,7 +1954,7 @@ def cut_loop(proc, loop_cursor, cut_point): `for i in seq(cut, n):` ` s` """ - ir, fwd = scheduling.DoCutLoop(loop_cursor._impl, cut_point) + ir, fwd = scheduling.DoCutLoop(loop_cursor._impl, cut_point, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1838,7 +1977,7 @@ def shift_loop(proc, loop_cursor, new_lo): `for i in seq(new_lo, new_lo + n - m):` ` s(i + (m - new_lo))` """ - ir, fwd = scheduling.DoShiftLoop(loop_cursor._impl, new_lo) + ir, fwd = scheduling.DoShiftLoop(loop_cursor._impl, new_lo, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1873,7 +2012,7 @@ def reorder_loops(proc, nested_loops): if len(stmt_c.body()) != 1 or not isinstance(stmt_c.body()[0]._node, LoopIR.For): raise ValueError(f"expected loop directly inside of {stmt_c._node.iter} loop") - ir, fwd = scheduling.DoLiftScope(stmt_c.body()[0]) + ir, fwd = scheduling.DoLiftScope(stmt_c.body()[0], proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2066,7 +2205,7 @@ def fission(proc, gap_cursor, n_lifts=1, unsafe_disable_checks=False): ) ir, fwd = scheduling.DoFissionAfterSimple( - stmt._impl, n_lifts, unsafe_disable_checks + stmt._impl, n_lifts, unsafe_disable_checks, proc._check_mode ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2148,9 +2287,9 @@ def fuse(proc, stmt1, stmt2, unsafe_disable_check=False): s1 = stmt1._impl s2 = stmt2._impl if isinstance(stmt1, PC.IfCursor): - ir, fwd = scheduling.DoFuseIf(s1, s2) + ir, fwd = scheduling.DoFuseIf(s1, s2, proc._check_mode) else: - ir, fwd = scheduling.DoFuseLoop(s1, s2, unsafe_disable_check) + ir, fwd = scheduling.DoFuseLoop(s1, s2, proc._check_mode, unsafe_disable_check) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2170,7 +2309,9 @@ def remove_loop(proc, loop_cursor, unsafe_disable_check=False): -> `s` """ - ir, fwd = scheduling.DoRemoveLoop(loop_cursor._impl, unsafe_disable_check) + ir, fwd = scheduling.DoRemoveLoop( + loop_cursor._impl, unsafe_disable_check, proc._check_mode + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2205,7 +2346,7 @@ def add_loop( stmt_c = 
block_cursor[0]._impl ir, fwd = scheduling.DoAddLoop( - stmt_c, iter_name, hi_expr, guard, unsafe_disable_check + stmt_c, iter_name, hi_expr, guard, unsafe_disable_check, proc._check_mode ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2259,7 +2400,7 @@ def lift_scope(proc, scope_cursor): """ stmt_c = scope_cursor._impl - ir, fwd = scheduling.DoLiftScope(stmt_c) + ir, fwd = scheduling.DoLiftScope(stmt_c, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2281,7 +2422,7 @@ def eliminate_dead_code(proc, stmt_cursor): `s1` """ - ir, fwd = scheduling.DoEliminateDeadCode(stmt_cursor._impl) + ir, fwd = scheduling.DoEliminateDeadCode(stmt_cursor._impl, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index c145ccd62..5544becb7 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -428,10 +428,13 @@ def access_tensor( f"let {repr(access_set_sym)}_pw=new Set();", f"let {repr(access_set_sym)}_cr=new Set();", f"let {repr(access_set_sym)}_cw=new Set();", - f"let {repr(self.coverage_sym)}=false;", ) ), ), + IndexedFiller( + self.parent_state.cov_placeholder, + f"let {repr(self.coverage_sym)}=false;", + ), IndexedFiller( access_placeholder, "".join( diff --git a/src/exo/backend/coverage.py b/src/exo/backend/coverage.py index 2483c1176..5ab8f035f 100644 --- a/src/exo/backend/coverage.py +++ b/src/exo/backend/coverage.py @@ -111,14 +111,13 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: new_solution = state.cm.solve_constraint( new_constraint, bound=state.bound, search_limit=state.search_limit ) - if ( - new_solution is None and state.is_base_constraint - ) or new_solution is not None: + if new_solution is None and state.is_base_constraint: + (self.true_child if uncovered_path else self.false_child).mark_visited() + if new_solution is not None: if uncovered_path: self.true_child.visited = True else: self.false_child.visited = True - if new_solution is not None: return state.update_solution(new_constraint, new_solution) elif self.true_child.visited and self.false_child.visited: return self.false_child.solve_coverage( @@ -182,6 +181,12 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: current_state = branch.solve_coverage(current_state) return current_state + def mark_visited(self): + self.visited = True + for branch in self.branches: + branch.true_child.mark_visited() + branch.false_child.mark_visited() + @dataclass class MemoryAccess: diff --git a/src/exo/frontend/typecheck.py b/src/exo/frontend/typecheck.py index 6bbfa8bd4..1518a88cc 100644 --- a/src/exo/frontend/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -135,6 +135,7 @@ def __init__(self, proc, check_mode: CheckMode): self.err( proc, f"expected writes to configuration {name[0].name()}.{name[1]} does not depend on loop iterations", + allowed_in_chexo=True, ) instr = proc.instr diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py index 028ca4d44..cc9bbf78d 100644 --- a/src/exo/libs/externs.py +++ b/src/exo/libs/externs.py @@ -369,4 +369,4 @@ def express_in_constraints( ) -intmin = _IntMin +intmin = _IntMin() diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index bb4d0bffc..f5b5650f0 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -1,7 +1,7 @@ import re from 
collections import ChainMap import traceback -from typing import Any, Callable, List, Literal, Tuple, Optional +from typing import Any, Callable, Generator, List, Literal, Tuple, Optional, Union from ..core.LoopIR import ( LoopIR, @@ -36,6 +36,7 @@ from .chexo import fuzz, fuzz_reorder_stmts from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis +from ..core.internal_cursors import Block, Node from ..core.prelude import * from ..core.proc_eqv import get_strictest_eqv_proc @@ -44,6 +45,7 @@ from ..frontend.pattern_match import match_pattern from ..core.memory import DRAM from ..frontend.typecheck import check_call_types, CheckMode +from ..libs.externs import intmin from functools import partial @@ -215,10 +217,64 @@ def _replace_pats(ir, fwd, c, pat, repl, only_replace_attrs=True, use_sym_id=Tru return ir, _compose(cur_fwd, fwd) +def find_all_sym_reads(c: Union[Block, Node], sym: Sym) -> Generator[Node, None, None]: + if isinstance(c, Block): + for node in c: + yield from find_all_sym_reads(node, sym) + else: + if isinstance(c._node, (LoopIR.Assign, LoopIR.Reduce)): + yield from find_all_sym_reads(c._child_block("idx"), sym) + yield from find_all_sym_reads(c._child_node("rhs"), sym) + elif isinstance(c._node, LoopIR.WriteConfig): + yield from find_all_sym_reads(c._child_node("rhs"), sym) + elif isinstance(c._node, LoopIR.If): + yield from find_all_sym_reads(c._child_node("cond"), sym) + yield from find_all_sym_reads(c._child_block("body"), sym) + yield from find_all_sym_reads(c._child_block("orelse"), sym) + elif isinstance(c._node, LoopIR.For): + yield from find_all_sym_reads(c._child_node("lo"), sym) + yield from find_all_sym_reads(c._child_node("hi"), sym) + yield from find_all_sym_reads(c._child_block("body"), sym) + elif isinstance(c._node, LoopIR.Alloc): + yield from find_all_sym_reads(c._child_node("type"), sym) + elif isinstance(c._node, LoopIR.Free): + yield from find_all_sym_reads(c._child_node("type"), sym) + elif isinstance(c._node, LoopIR.Call): + yield from find_all_sym_reads(c._child_block("args"), sym) + elif isinstance(c._node, LoopIR.WindowStmt): + yield from find_all_sym_reads(c._child_node("rhs"), sym) + elif isinstance(c._node, LoopIR.expr): + yield from find_all_sym_reads(c._child_node("type"), sym) + if isinstance(c._node, LoopIR.Read): + yield from find_all_sym_reads(c._child_block("idx"), sym) + if c._node.name == sym: + yield c + elif isinstance(c._node, LoopIR.USub): + yield from find_all_sym_reads(c._child_node("arg"), sym) + elif isinstance(c._node, LoopIR.BinOp): + yield from find_all_sym_reads(c._child_node("lhs"), sym) + yield from find_all_sym_reads(c._child_node("rhs"), sym) + elif isinstance(c._node, LoopIR.Extern): + yield from find_all_sym_reads(c._child_block("args"), sym) + elif isinstance(c._node, LoopIR.WindowExpr): + yield from find_all_sym_reads(c._child_block("idx"), sym) + elif isinstance(c._node, LoopIR.Tensor): + yield from find_all_sym_reads(c._child_block("hi"), sym) + elif isinstance(c._node, LoopIR.WindowType): + yield from find_all_sym_reads(c._child_node("src_type"), sym) + yield from find_all_sym_reads(c._child_node("as_tensor"), sym) + yield from find_all_sym_reads(c._child_block("idx"), sym) + elif isinstance(c._node, LoopIR.Interval): + yield from find_all_sym_reads(c._child_node("lo"), sym) + yield from find_all_sym_reads(c._child_node("hi"), sym) + elif isinstance(c._node, LoopIR.Point): + yield from find_all_sym_reads(c._child_node("pt"), sym) + + def _replace_reads(ir, fwd, c, sym, repl, 
only_replace_attrs=True): c = fwd(c) todos = [] - for rd in match_pattern(c, f"{repr(sym)}[_]", use_sym_id=True): + for rd in find_all_sym_reads(c, sym): # Need [_] to pattern match against window expressiontatic if c_repl := repl(rd): todos.append((rd, c_repl)) @@ -419,97 +475,182 @@ def DoReorderStmt(f_cursor, s_cursor, check_mode: CheckMode): return ir, fwd -def DoParallelizeLoop(loop_cursor): - return loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) +def DoParallelizeLoop(loop_cursor, check_mode: CheckMode): + ir, fwd = loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) + def static_check(): + pass -def DoJoinLoops(loop1_c, loop2_c): - if loop1_c.next() != loop2_c: - raise SchedulingError("expected the second loop to be directly after the first") + def dynamic_check(): + fuzz(loop_cursor, fwd) - loop1 = loop1_c._node - loop2 = loop2_c._node + do_check(static_check, dynamic_check, check_mode) + return ir, fwd - try: - Check_ExprEqvInContext(loop1_c.get_root(), loop1.hi, [loop1], loop2.lo, [loop2]) - except Exception as e: - raise SchedulingError( - f"expected the first loop upper bound {loop1.hi} to be the same as the second loop lower bound {loop2.lo}" - ) - compare_ir = LoopIR_Compare() - if not compare_ir.match_stmts(loop1.body, loop2.body): - raise SchedulingError("expected the two loops to have identical bodies") +def DoJoinLoops(loop1_c, loop2_c, check_mode: CheckMode): + def static_check(): + if loop1_c.next() != loop2_c: + raise SchedulingError( + "expected the second loop to be directly after the first" + ) + + loop1 = loop1_c._node + loop2 = loop2_c._node + + try: + Check_ExprEqvInContext( + loop1_c.get_root(), loop1.hi, [loop1], loop2.lo, [loop2] + ) + except Exception as e: + raise SchedulingError( + f"expected the first loop upper bound {loop1.hi} to be the same as the second loop lower bound {loop2.lo}" + ) + + compare_ir = LoopIR_Compare() + if not compare_ir.match_stmts(loop1.body, loop2.body): + raise SchedulingError("expected the two loops to have identical bodies") - ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) - ir, fwd_del = fwd(loop2_c)._delete() + ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) + ir, fwd_del = fwd(loop2_c)._delete() - return ir, _compose(fwd_del, fwd) + return ir, _compose(fwd_del, fwd) + def dynamic_check(): + if loop1_c.next() != loop2_c: + raise SchedulingError( + "expected the second loop to be directly after the first" + ) -def DoCutLoop(loop_c, cut_point): - s = loop_c._node + loop1 = loop1_c._node + loop2 = loop2_c._node - assert isinstance(s, LoopIR.For) + compare_ir = LoopIR_Compare() + if not compare_ir.match_stmts(loop1.body, loop2.body): + raise SchedulingError("expected the two loops to have identical bodies") - ir = loop_c.get_root() + ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) + ir, fwd_del = fwd(loop2_c)._delete() + fuzz(loop1_c.as_block().expand(delta_lo=0, delta_hi=1), fwd) - try: - Check_CompareExprs(ir, [s], cut_point, ">=", s.lo) - except SchedulingError: - raise SchedulingError(f"Expected `lo` <= `cut_point`") + return ir, _compose(fwd_del, fwd) - try: - Check_CompareExprs(ir, [s], s.hi, ">=", cut_point) - except SchedulingError: - raise SchedulingError(f"Expected `cut_point` <= `hi`") + return do_check(static_check, dynamic_check, check_mode) - ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) - loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] - ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) - fwd = _compose(fwd2, fwd1) - return ir, fwd +def DoCutLoop(loop_c, 
cut_point, check_mode: CheckMode): + def static_check(): + s = loop_c._node + assert isinstance(s, LoopIR.For) -def DoShiftLoop(loop_c, new_lo): - s = loop_c._node + ir = loop_c.get_root() - assert isinstance(s, LoopIR.For) + try: + Check_CompareExprs(ir, [s], cut_point, ">=", s.lo) + except SchedulingError: + raise SchedulingError(f"Expected `lo` <= `cut_point`") - try: - Check_IsNonNegativeExpr( - loop_c.get_root(), - [s], - new_lo, + try: + Check_CompareExprs(ir, [s], s.hi, ">=", cut_point) + except SchedulingError: + raise SchedulingError(f"Expected `cut_point` <= `hi`") + + ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) + loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] + ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) + fwd = _compose(fwd2, fwd1) + + return ir, fwd + + def dynamic_check(): + s = loop_c._node + + assert isinstance(s, LoopIR.For) + + ir = loop_c.get_root() + + ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) + loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] + ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) + fwd = _compose(fwd2, fwd1) + fuzz(loop_c.parent(), fwd) + + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) + + +def DoShiftLoop(loop_c, new_lo, check_mode: CheckMode): + def static_check(): + s = loop_c._node + + assert isinstance(s, LoopIR.For) + + try: + Check_IsNonNegativeExpr( + loop_c.get_root(), + [s], + new_lo, + ) + except SchedulingError: + raise SchedulingError(f"Expected 0 <= `new_lo`") + + loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) + new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) + + ir, fwd1 = loop_c._child_node("lo")._replace(new_lo) + ir, fwd2 = fwd1(loop_c)._child_node("hi")._replace(new_hi) + fwd12 = _compose(fwd2, fwd1) + + # all uses of the loop iteration in the second body need + # to be offset by (`lo` - `new_lo``) + loop_iter = s.iter + iter_node = LoopIR.Read(loop_iter, [], T.index, s.srcinfo) + iter_offset = LoopIR.BinOp("-", s.lo, new_lo, T.index, s.srcinfo) + new_iter = LoopIR.BinOp("+", iter_node, iter_offset, T.index, s.srcinfo) + + ir, fwd = _replace_reads( + ir, + fwd12, + loop_c, + loop_iter, + lambda _: new_iter, + only_replace_attrs=False, ) - except SchedulingError: - raise SchedulingError(f"Expected 0 <= `new_lo`") + return ir, fwd - loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) - new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) + def dynamic_check(): + s = loop_c._node - ir, fwd1 = loop_c._child_node("lo")._replace(new_lo) - ir, fwd2 = fwd1(loop_c)._child_node("hi")._replace(new_hi) - fwd12 = _compose(fwd2, fwd1) + assert isinstance(s, LoopIR.For) - # all uses of the loop iteration in the second body need - # to be offset by (`lo` - `new_lo``) - loop_iter = s.iter - iter_node = LoopIR.Read(loop_iter, [], T.index, s.srcinfo) - iter_offset = LoopIR.BinOp("-", s.lo, new_lo, T.index, s.srcinfo) - new_iter = LoopIR.BinOp("+", iter_node, iter_offset, T.index, s.srcinfo) + loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) + new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) - ir, fwd = _replace_reads( - ir, - fwd12, - loop_c, - loop_iter, - lambda _: new_iter, - only_replace_attrs=False, - ) + ir, fwd1 = loop_c._child_node("lo")._replace(new_lo) + ir, fwd2 = fwd1(loop_c)._child_node("hi")._replace(new_hi) + fwd12 = _compose(fwd2, fwd1) - return ir, fwd + # all uses of the loop iteration in the second body need + # to be offset by (`lo` - `new_lo``) + loop_iter = 
s.iter + iter_node = LoopIR.Read(loop_iter, [], T.index, s.srcinfo) + iter_offset = LoopIR.BinOp("-", s.lo, new_lo, T.index, s.srcinfo) + new_iter = LoopIR.BinOp("+", iter_node, iter_offset, T.index, s.srcinfo) + + ir, fwd = _replace_reads( + ir, + fwd12, + loop_c, + loop_iter, + lambda _: new_iter, + only_replace_attrs=False, + ) + fuzz(loop_c, fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) def DoProductLoop(outer_loop_c, new_name): @@ -700,7 +841,12 @@ def mk_inline_expr(e): def DoDivideWithRecompute( - loop_cursor, outer_hi, outer_stride: int, iter_o: str, iter_i: str + loop_cursor, + outer_hi, + outer_stride: int, + iter_o: str, + iter_i: str, + check_mode: CheckMode, ): proc = loop_cursor.get_root() loop = loop_cursor._node @@ -708,84 +854,422 @@ def DoDivideWithRecompute( assert isinstance(loop, LoopIR.For) assert isinstance(outer_hi, LoopIR.expr) - Check_IsIdempotent(proc, loop.body) - def rd(i): - return LoopIR.Read(i, [], T.index, srcinfo) + def static_check(): + Check_IsIdempotent(proc, loop.body) - def cnst(intval): - return LoopIR.Const(intval, T.int, srcinfo) + def rd(i): + return LoopIR.Read(i, [], T.index, srcinfo) - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) - sym_o = Sym(iter_o) - sym_i = Sym(iter_i) - x = cnst(outer_stride) + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - if ( - isinstance(outer_hi, LoopIR.BinOp) - and outer_hi.op == "/" - and isinstance(outer_hi.rhs, LoopIR.Const) - and outer_hi.rhs.val == outer_stride - ): - N_before_recompute = szop("-", outer_hi.lhs, szop("%", outer_hi.lhs, x)) - else: - N_before_recompute = szop("*", outer_hi, x) + sym_o = Sym(iter_o) + sym_i = Sym(iter_i) + x = cnst(outer_stride) - N_recompute = LoopIR.BinOp("-", loop.hi, N_before_recompute, T.index, srcinfo) - try: - Check_IsNonNegativeExpr(proc, [loop], N_recompute) - except SchedulingError: - raise SchedulingError(f"outer_hi * outer_stride exceeds loop's hi {loop.hi}") + if ( + isinstance(outer_hi, LoopIR.BinOp) + and outer_hi.op == "/" + and isinstance(outer_hi.rhs, LoopIR.Const) + and outer_hi.rhs.val == outer_stride + ): + N_before_recompute = szop("-", outer_hi.lhs, szop("%", outer_hi.lhs, x)) + else: + N_before_recompute = szop("*", outer_hi, x) - hi_o = outer_hi - hi_i = szop("+", x, N_recompute) + N_recompute = LoopIR.BinOp("-", loop.hi, N_before_recompute, T.index, srcinfo) + try: + Check_IsNonNegativeExpr(proc, [loop], N_recompute) + except SchedulingError: + raise SchedulingError( + f"outer_hi * outer_stride exceeds loop's hi {loop.hi}" + ) - # turn current loop into outer loop - ir, fwd = loop_cursor._child_node("iter")._replace(sym_o) - ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(hi_o) - fwd = _compose(fwd_repl, fwd) + hi_o = outer_hi + hi_i = szop("+", x, N_recompute) - # wrap body in inner loop - def inner_wrapper(body): - return LoopIR.For( - sym_i, - LoopIR.Const(0, T.index, srcinfo), - hi_i, - body, - LoopIR.Seq(), - srcinfo, + # turn current loop into outer loop + ir, fwd = loop_cursor._child_node("iter")._replace(sym_o) + ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(hi_o) + fwd = _compose(fwd_repl, fwd) + + # wrap body in inner loop + def inner_wrapper(body): + return LoopIR.For( + sym_i, + LoopIR.Const(0, T.index, srcinfo), + hi_i, + body, + LoopIR.Seq(), + srcinfo, + ) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + fwd = 
_compose(fwd_wrap, fwd) + + # replace the iteration variable in the body + def mk_iter(_): + return szop("+", szop("*", rd(sym_o), x), rd(sym_i)) + + ir, fwd = _replace_reads( + ir, + fwd, + loop_cursor, + loop.iter, + mk_iter, + only_replace_attrs=False, ) - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) + return ir, fwd - # replace the iteration variable in the body - def mk_iter(_): - return szop("+", szop("*", rd(sym_o), x), rd(sym_i)) + def dynamic_check(): + def rd(i): + return LoopIR.Read(i, [], T.index, srcinfo) - ir, fwd = _replace_reads( - ir, - fwd, - loop_cursor, - loop.iter, - mk_iter, - only_replace_attrs=False, - ) + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) - return ir, fwd + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) + + sym_o = Sym(iter_o) + sym_i = Sym(iter_i) + x = cnst(outer_stride) + + if ( + isinstance(outer_hi, LoopIR.BinOp) + and outer_hi.op == "/" + and isinstance(outer_hi.rhs, LoopIR.Const) + and outer_hi.rhs.val == outer_stride + ): + N_before_recompute = szop("-", outer_hi.lhs, szop("%", outer_hi.lhs, x)) + else: + N_before_recompute = szop("*", outer_hi, x) + + N_recompute = LoopIR.BinOp("-", loop.hi, N_before_recompute, T.index, srcinfo) + + hi_o = outer_hi + hi_i = szop("+", x, N_recompute) + + # turn current loop into outer loop + ir, fwd = loop_cursor._child_node("iter")._replace(sym_o) + ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(hi_o) + fwd = _compose(fwd_repl, fwd) + + # wrap body in inner loop + def inner_wrapper(body): + return LoopIR.For( + sym_i, + LoopIR.Const(0, T.index, srcinfo), + hi_i, + body, + LoopIR.Seq(), + srcinfo, + ) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # replace the iteration variable in the body + def mk_iter(_): + return szop("+", szop("*", rd(sym_o), x), rd(sym_i)) + + ir, fwd = _replace_reads( + ir, + fwd, + loop_cursor, + loop.iter, + mk_iter, + only_replace_attrs=False, + ) + fuzz(loop_cursor.parent(), fwd) + + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) def DoDivideLoop( - loop_cursor, quot, outer_iter, inner_iter, tail="guard", perfect=False + loop_cursor, + quot, + outer_iter, + inner_iter, + check_mode: CheckMode, + tail="guard", + perfect=False, +): + def static_check(): + loop = loop_cursor._node + N = loop.hi + outer_i = Sym(outer_iter) + inner_i = Sym(inner_iter) + srcinfo = loop.srcinfo + tail_strategy = "perfect" if perfect else tail + + if not is_const_zero(loop.lo): + raise SchedulingError( + f"expected the lower bound of the loop to be zero, got {loop.lo}." 
+ ) + + def substitute(srcinfo): + cnst = lambda x: LoopIR.Const(x, T.int, srcinfo) + rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) + op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) + + return op("+", op("*", cnst(quot), rd(outer_i)), rd(inner_i)) + + # short-hands for sanity + def boolop(op, lhs, rhs, typ): + return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) + + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) + + def rd(i): + return LoopIR.Read(i, [], T.index, srcinfo) + + def ceildiv(lhs, rhs): + assert isinstance(rhs, LoopIR.Const) and rhs.val > 0 + rhs_1 = cnst(rhs.val - 1) + return szop("/", szop("+", lhs, rhs_1), rhs) + + # determine hi and lo loop bounds + inner_hi = cnst(quot) + if tail_strategy in ["guard"]: + outer_hi = ceildiv(N, inner_hi) + elif tail_strategy in ["cut", "cut_and_guard"]: + outer_hi = szop("/", N, inner_hi) # floor div + elif tail_strategy == "perfect": + ir = loop_cursor.get_root() + loop = loop_cursor._node + Check_IsDivisible(ir, [loop], N, quot) + outer_hi = divide_expr(N, quot) + else: + assert False, f"bad tail strategy: {tail_strategy}" + + # turn current loop into outer loop + ir, fwd = loop_cursor._child_node("iter")._replace(outer_i) + ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(outer_hi) + fwd = _compose(fwd_repl, fwd) + + # wrap body in a guard + if tail_strategy == "guard": + idx_sub = substitute(srcinfo) + + def guard_wrapper(body): + cond = boolop("<", idx_sub, N, T.bool) + return LoopIR.If(cond, body, [], srcinfo) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(guard_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # wrap body in inner loop + def inner_wrapper(body): + return LoopIR.For( + inner_i, + LoopIR.Const(0, T.index, srcinfo), + inner_hi, + body, + loop.loop_mode, + srcinfo, + ) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # replace the iteration variable in the body + def mk_main_iter(c): + return substitute(c._node.srcinfo) + + ir, fwd = _replace_reads( + ir, + fwd, + loop_cursor, + loop.iter, + mk_main_iter, + only_replace_attrs=False, + ) + + # add the tail case + if tail_strategy in ["cut", "cut_and_guard"]: + cut_i = Sym(inner_iter) + Ntail = szop("%", N, inner_hi) + + # in the tail loop we want the iteration variable to + # be mapped instead to (Ncut*Q + cut_i) + cut_tail_sub = szop("+", rd(cut_i), szop("*", outer_hi, inner_hi)) + + cut_body = Alpha_Rename(loop.body).result() + env = {loop.iter: cut_tail_sub} + cut_body = SubstArgs(cut_body, env).result() + + cut_s = LoopIR.For( + cut_i, + LoopIR.Const(0, T.index, srcinfo), + Ntail, + cut_body, + loop.loop_mode, + srcinfo, + ) + if tail_strategy == "cut_and_guard": + cond = boolop(">", Ntail, LoopIR.Const(0, T.int, srcinfo), T.bool) + cut_s = LoopIR.If(cond, [cut_s], [], srcinfo) + + ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) + fwd = _compose(fwd_ins, fwd) + + return ir, fwd + + def dynamic_check(): + loop = loop_cursor._node + N = loop.hi + outer_i = Sym(outer_iter) + inner_i = Sym(inner_iter) + srcinfo = loop.srcinfo + tail_strategy = "perfect" if perfect else tail + + if not is_const_zero(loop.lo): + raise SchedulingError( + f"expected the lower bound of the loop to be zero, got {loop.lo}." 
+ ) + + def substitute(srcinfo): + cnst = lambda x: LoopIR.Const(x, T.int, srcinfo) + rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) + op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) + + return op("+", op("*", cnst(quot), rd(outer_i)), rd(inner_i)) + + # short-hands for sanity + def boolop(op, lhs, rhs, typ): + return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) + + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) + + def rd(i): + return LoopIR.Read(i, [], T.index, srcinfo) + + def ceildiv(lhs, rhs): + assert isinstance(rhs, LoopIR.Const) and rhs.val > 0 + rhs_1 = cnst(rhs.val - 1) + return szop("/", szop("+", lhs, rhs_1), rhs) + + # determine hi and lo loop bounds + inner_hi = cnst(quot) + if tail_strategy in ["guard"]: + outer_hi = ceildiv(N, inner_hi) + elif tail_strategy in ["cut", "cut_and_guard"]: + outer_hi = szop("/", N, inner_hi) # floor div + elif tail_strategy == "perfect": + ir = loop_cursor.get_root() + loop = loop_cursor._node + outer_hi = divide_expr(N, quot) + else: + assert False, f"bad tail strategy: {tail_strategy}" + + # turn current loop into outer loop + ir, fwd = loop_cursor._child_node("iter")._replace(outer_i) + ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(outer_hi) + fwd = _compose(fwd_repl, fwd) + + # wrap body in a guard + if tail_strategy == "guard": + idx_sub = substitute(srcinfo) + + def guard_wrapper(body): + cond = boolop("<", idx_sub, N, T.bool) + return LoopIR.If(cond, body, [], srcinfo) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(guard_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # wrap body in inner loop + def inner_wrapper(body): + return LoopIR.For( + inner_i, + LoopIR.Const(0, T.index, srcinfo), + inner_hi, + body, + loop.loop_mode, + srcinfo, + ) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # replace the iteration variable in the body + def mk_main_iter(c): + return substitute(c._node.srcinfo) + + ir, fwd = _replace_reads( + ir, + fwd, + loop_cursor, + loop.iter, + mk_main_iter, + only_replace_attrs=False, + ) + + # add the tail case + if tail_strategy in ["cut", "cut_and_guard"]: + cut_i = Sym(inner_iter) + Ntail = szop("%", N, inner_hi) + + # in the tail loop we want the iteration variable to + # be mapped instead to (Ncut*Q + cut_i) + cut_tail_sub = szop("+", rd(cut_i), szop("*", outer_hi, inner_hi)) + + cut_body = Alpha_Rename(loop.body).result() + env = {loop.iter: cut_tail_sub} + cut_body = SubstArgs(cut_body, env).result() + + cut_s = LoopIR.For( + cut_i, + LoopIR.Const(0, T.index, srcinfo), + Ntail, + cut_body, + loop.loop_mode, + srcinfo, + ) + if tail_strategy == "cut_and_guard": + cond = boolop(">", Ntail, LoopIR.Const(0, T.int, srcinfo), T.bool) + cut_s = LoopIR.If(cond, [cut_s], [], srcinfo) + + ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) + fwd = _compose(fwd_ins, fwd) + + if tail_strategy == "perfect": + fuzz(loop_cursor.parent(), fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) + + +def DoDivideLoopMin( + loop_cursor, + quot, + outer_iter, + inner_iter, + check_mode: CheckMode, ): + if check_mode != "dynamic": + raise SchedulingError("cannot use min tail strategy without chexo") loop = loop_cursor._node N = loop.hi outer_i = Sym(outer_iter) inner_i = Sym(inner_iter) srcinfo = loop.srcinfo - tail_strategy = "perfect" if perfect else tail if not is_const_zero(loop.lo): raise 
SchedulingError( @@ -793,72 +1277,68 @@ def DoDivideLoop( ) def substitute(srcinfo): - cnst = lambda x: LoopIR.Const(x, T.int, srcinfo) rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - return op("+", op("*", cnst(quot), rd(outer_i)), rd(inner_i)) + return op("+", op("*", quot, rd(outer_i)), rd(inner_i)) # short-hands for sanity def boolop(op, lhs, rhs, typ): return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) - def cnst(intval): return LoopIR.Const(intval, T.int, srcinfo) + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + def rd(i): return LoopIR.Read(i, [], T.index, srcinfo) def ceildiv(lhs, rhs): - assert isinstance(rhs, LoopIR.Const) and rhs.val > 0 - rhs_1 = cnst(rhs.val - 1) + rhs_1 = LoopIR.BinOp("-", rhs, cnst(1), rhs.type, srcinfo) return szop("/", szop("+", lhs, rhs_1), rhs) # determine hi and lo loop bounds - inner_hi = cnst(quot) - if tail_strategy == "guard": - outer_hi = ceildiv(N, inner_hi) - elif tail_strategy in ["cut", "cut_and_guard"]: - outer_hi = szop("/", N, inner_hi) # floor div - elif tail_strategy == "perfect": - ir = loop_cursor.get_root() - loop = loop_cursor._node - Check_IsDivisible(ir, [loop], N, quot) - outer_hi = divide_expr(N, quot) - else: - assert False, f"bad tail strategy: {tail_strategy}" + inner_hi = quot + outer_hi = ceildiv(N, inner_hi) # turn current loop into outer loop ir, fwd = loop_cursor._child_node("iter")._replace(outer_i) ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(outer_hi) fwd = _compose(fwd_repl, fwd) - # wrap body in a guard - if tail_strategy == "guard": - idx_sub = substitute(srcinfo) - - def guard_wrapper(body): - cond = boolop("<", idx_sub, N, T.bool) - return LoopIR.If(cond, body, [], srcinfo) - - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(guard_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - # wrap body in inner loop - def inner_wrapper(body): + def inner_wrapper_min(body): return LoopIR.For( inner_i, LoopIR.Const(0, T.index, srcinfo), - inner_hi, + LoopIR.Extern( + intmin, + [ + inner_hi, + LoopIR.BinOp( + "-", + N, + LoopIR.BinOp( + "*", + quot, + LoopIR.Read(outer_i, [], T.index, srcinfo), + T.index, + srcinfo, + ), + N.type, + srcinfo, + ), + ], + T.size, + srcinfo, + ), body, loop.loop_mode, srcinfo, ) - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper_min, "body") fwd = _compose(fwd_wrap, fwd) # replace the iteration variable in the body @@ -874,40 +1354,13 @@ def mk_main_iter(c): only_replace_attrs=False, ) - # add the tail case - if tail_strategy in ["cut", "cut_and_guard"]: - cut_i = Sym(inner_iter) - Ntail = szop("%", N, inner_hi) + fuzz(loop_cursor.parent(), fwd) + return ir, fwd - # in the tail loop we want the iteration variable to - # be mapped instead to (Ncut*Q + cut_i) - cut_tail_sub = szop("+", rd(cut_i), szop("*", outer_hi, inner_hi)) - cut_body = Alpha_Rename(loop.body).result() - env = {loop.iter: cut_tail_sub} - cut_body = SubstArgs(cut_body, env).result() - - cut_s = LoopIR.For( - cut_i, - LoopIR.Const(0, T.index, srcinfo), - Ntail, - cut_body, - loop.loop_mode, - srcinfo, - ) - if tail_strategy == "cut_and_guard": - cond = boolop(">", Ntail, LoopIR.Const(0, T.int, srcinfo), T.bool) - cut_s = LoopIR.If(cond, [cut_s], [], srcinfo) - - ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) - fwd = _compose(fwd_ins, fwd) - - 
return ir, fwd - - -# --------------------------------------------------------------------------- # -# --------------------------------------------------------------------------- # -# Unroll scheduling directive +# --------------------------------------------------------------------------- # +# --------------------------------------------------------------------------- # +# Unroll scheduling directive def DoUnroll(c_loop): @@ -1181,20 +1634,41 @@ def mk_write(c): return ir, fwd -def DoConfigWrite(stmt_cursor, config, field, expr, before=False): - assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) - s = stmt_cursor._node +def DoConfigWrite( + stmt_cursor, config, field, expr, check_mode: CheckMode, before=False +): + def static_check(): + assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) + s = stmt_cursor._node + + cw_s = LoopIR.WriteConfig(config, field, expr, s.srcinfo) + + if before: + ir, fwd = stmt_cursor.before()._insert([cw_s]) + else: + ir, fwd = stmt_cursor.after()._insert([cw_s]) - cw_s = LoopIR.WriteConfig(config, field, expr, s.srcinfo) + cfg = Check_DeleteConfigWrite(ir, [cw_s]) + return ir, fwd, cfg - if before: - ir, fwd = stmt_cursor.before()._insert([cw_s]) - else: - ir, fwd = stmt_cursor.after()._insert([cw_s]) + def dynamic_check(): + assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) + s = stmt_cursor._node + + cw_s = LoopIR.WriteConfig(config, field, expr, s.srcinfo) - cfg = Check_DeleteConfigWrite(ir, [cw_s]) + if before: + ir, fwd1 = stmt_cursor.before()._insert([LoopIR.Pass(s.srcinfo)]) + pass_cursor = fwd1(stmt_cursor).prev() + else: + ir, fwd1 = stmt_cursor.after()._insert([LoopIR.Pass(s.srcinfo)]) + pass_cursor = fwd1(stmt_cursor).next() + ir, fwd2 = pass_cursor._replace([cw_s]) - return ir, fwd, cfg + fuzz(pass_cursor, fwd2) + return ir, _compose(fwd2, fwd1), None + + return do_check(static_check, dynamic_check, check_mode) # --------------------------------------------------------------------------- # @@ -1228,7 +1702,6 @@ def static_check(): def dynamic_check(): e = expr_cursor._node - assert isinstance(e, LoopIR.Read) c = expr_cursor while not isinstance(c._node, LoopIR.stmt): @@ -1299,212 +1772,422 @@ def match_parent(c1, c2): return c1, c2 -def DoRewriteExpr(expr_cursor, new_expr): - proc = expr_cursor.get_root() - s = get_enclosing_stmt_cursor(expr_cursor)._node - Check_ExprEqvInContext(proc, expr_cursor._node, [s], new_expr, [s]) - return expr_cursor._replace(new_expr) +def DoRewriteExpr(expr_cursor, new_expr, check_mode): + ir, fwd = expr_cursor._replace(new_expr) + def static_check(): + proc = expr_cursor.get_root() + s = get_enclosing_stmt_cursor(expr_cursor)._node + Check_ExprEqvInContext(proc, expr_cursor._node, [s], new_expr, [s]) -def DoBindExpr(new_name, expr_cursors): - assert expr_cursors + def dynamic_check(): + fuzz(get_enclosing_stmt_cursor(expr_cursor), fwd) - expr = expr_cursors[0]._node - assert isinstance(expr, LoopIR.expr) - assert expr.type.is_numeric() + do_check(static_check, dynamic_check, check_mode) + return ir, fwd - expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] - # TODO: dirty hack. need real CSE-equality (i.e. 
modulo srcinfo) - expr_cursors = [c for c in expr_cursors if str(c._node) == str(expr)] - init_s = get_enclosing_stmt_cursor(expr_cursors[0]) - if len(expr_cursors) > 1: - # TODO: Currently assume expr cursors is sorted in order - init_s, _ = match_parent(init_s, expr_cursors[-1]) +def DoBindExpr(new_name, expr_cursors, check_mode: CheckMode): + def static_check(): + assert expr_cursors - new_name = Sym(new_name) - alloc_s = LoopIR.Alloc(new_name, expr.type.basetype(), DRAM, expr.srcinfo) - assign_s = LoopIR.Assign(new_name, expr.type.basetype(), [], expr, expr.srcinfo) - ir, fwd = init_s.before()._insert([alloc_s, assign_s]) - - new_read = LoopIR.Read(new_name, [], expr.type, expr.srcinfo) - first_write_c = None - for c in get_rest_of_block(init_s, inclusive=True): - for block in match_pattern(c, "_ = _") + match_pattern(c, "_ += _"): - assert len(block) == 1 - sc = block[0] - if sc._node.name in expr_reads: - first_write_c = sc + expr = expr_cursors[0]._node + assert isinstance(expr, LoopIR.expr) + assert expr.type.is_numeric() + + expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] + # TODO: dirty hack. need real CSE-equality (i.e. modulo srcinfo) + expr_cursors_eq = [c for c in expr_cursors if str(c._node) == str(expr)] + + init_s = get_enclosing_stmt_cursor(expr_cursors_eq[0]) + if len(expr_cursors_eq) > 1: + # TODO: Currently assume expr cursors is sorted in order + init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) + + new_name_sym = Sym(new_name) + alloc_s = LoopIR.Alloc(new_name_sym, expr.type.basetype(), DRAM, expr.srcinfo) + assign_s = LoopIR.Assign( + new_name_sym, expr.type.basetype(), [], expr, expr.srcinfo + ) + ir, fwd = init_s.before()._insert([alloc_s, assign_s]) + + new_read = LoopIR.Read(new_name_sym, [], expr.type, expr.srcinfo) + first_write_c = None + for c in get_rest_of_block(init_s, inclusive=True): + for block in match_pattern(c, "_ = _") + match_pattern(c, "_ += _"): + assert len(block) == 1 + sc = block[0] + if sc._node.name in expr_reads: + first_write_c = sc + break + + if first_write_c and isinstance(c._node, (LoopIR.For, LoopIR.If)): + # Potentially unsafe to partially bind, err on side of caution for now + break + + while expr_cursors_eq and c.is_ancestor_of(expr_cursors_eq[0]): + ir, fwd_repl = _replace_helper( + fwd(expr_cursors_eq[0]), new_read, only_replace_attrs=False + ) + fwd = _compose(fwd_repl, fwd) + expr_cursors_eq.pop(0) + + if first_write_c: break - if first_write_c and isinstance(c._node, (LoopIR.For, LoopIR.If)): - # Potentially unsafe to partially bind, err on side of caution for now - break + if len(expr_cursors_eq) > 0: + raise SchedulingError("Unsafe to bind all of the provided exprs.") - while expr_cursors and c.is_ancestor_of(expr_cursors[0]): - ir, fwd_repl = _replace_helper( - fwd(expr_cursors[0]), new_read, only_replace_attrs=False - ) - fwd = _compose(fwd_repl, fwd) - expr_cursors.pop(0) + Check_Aliasing(ir) + return ir, fwd - if first_write_c: - break + def dynamic_check(): + assert expr_cursors + + expr = expr_cursors[0]._node + assert isinstance(expr, LoopIR.expr) + et = expr.type if expr.type.is_numeric() else T.i32 + + expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] + # TODO: dirty hack. need real CSE-equality (i.e. 
modulo srcinfo) + expr_cursors_eq = [c for c in expr_cursors if str(c._node) == str(expr)] + + init_s = get_enclosing_stmt_cursor(expr_cursors_eq[0]) + if len(expr_cursors_eq) > 1: + # TODO: Currently assume expr cursors is sorted in order + init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) + + new_name_sym = Sym(new_name) + alloc_s = LoopIR.Alloc(new_name_sym, et, DRAM, expr.srcinfo) + assign_s = LoopIR.Assign(new_name_sym, et, [], expr, expr.srcinfo) + ir, fwd1 = init_s.before()._insert([LoopIR.Pass(expr.srcinfo)]) + pass_cursor = fwd1(init_s).prev() + ir, fwd2 = pass_cursor._replace([alloc_s, assign_s]) + + new_read = LoopIR.Read(new_name_sym, [], et, expr.srcinfo) + for c in get_rest_of_block(init_s, inclusive=True): + while expr_cursors_eq and c.is_ancestor_of(expr_cursors_eq[0]): + ir, fwd_repl = _replace_helper( + fwd2(fwd1(expr_cursors_eq[0])), new_read, only_replace_attrs=False + ) + fwd2 = _compose(fwd_repl, fwd2) + expr_cursors_eq.pop(0) - if len(expr_cursors) > 0: - raise SchedulingError("Unsafe to bind all of the provided exprs.") + if len(expr_cursors_eq) > 0: + raise SchedulingError("Unsafe to bind all of the provided exprs.") - Check_Aliasing(ir) - return ir, fwd + fuzz(get_rest_of_block(pass_cursor, inclusive=True), fwd2) + return ir, _compose(fwd2, fwd1) + + return do_check(static_check, dynamic_check, check_mode) -def DoLiftScope(inner_c): - inner_s = inner_c._node - assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) - target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" +def DoLiftScope(inner_c, check_mode: CheckMode): + def static_check(): + inner_s = inner_c._node + assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) + target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" + + outer_c = inner_c.parent() + if outer_c.root() == outer_c: + raise SchedulingError("Cannot lift scope of top-level statement") + outer_s = outer_c._node + + ir, fwd = inner_c.get_root(), lambda x: x + + if isinstance(outer_s, LoopIR.If): + + def if_wrapper(body, insert_orelse=False): + src = outer_s.srcinfo + # this is needed because _replace expects a non-zero length block + orelse = [LoopIR.Pass(src)] if insert_orelse else [] + return LoopIR.If(outer_s.cond, body, orelse, src) + + def orelse_wrapper(orelse): + src = outer_s.srcinfo + body = [LoopIR.Pass(src)] + return LoopIR.If(outer_s.cond, body, orelse, src) + + if isinstance(inner_s, LoopIR.If): + if inner_s in outer_s.body: + # if INNER: + # if OUTER: if OUTER: A + # if INNER: A else: C + # else: B ~> else: + # else: C if OUTER: B + # else: C + if len(outer_s.body) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) - outer_c = inner_c.parent() - if outer_c.root() == outer_c: - raise SchedulingError("Cannot lift scope of top-level statement") - outer_s = outer_c._node + blk_c = outer_s.orelse + wrapper = lambda body: if_wrapper(body, insert_orelse=bool(blk_c)) - ir, fwd = inner_c.get_root(), lambda x: x + ir, fwd = inner_c.body()._wrap(wrapper, "body") + if blk_c: + ir, fwd_repl = fwd(inner_c).body()[0].orelse()._replace(blk_c) + fwd = _compose(fwd_repl, fwd) - if isinstance(outer_s, LoopIR.If): + if inner_s.orelse: + ir, fwd_wrap = fwd(inner_c).orelse()._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + if blk_c: + ir, fwd_repl = ( + fwd(inner_c).orelse()[0].orelse()._replace(blk_c) + ) + fwd = _compose(fwd_repl, fwd) + else: + # if INNER: + # if OUTER: A if OUTER: A + # else: else: B + # if INNER: B ~> else: + # else: C if OUTER: A + # 
else: C + assert inner_s in outer_s.orelse + if len(outer_s.orelse) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) - def if_wrapper(body, insert_orelse=False): - src = outer_s.srcinfo - # this is needed because _replace expects a non-zero length block - orelse = [LoopIR.Pass(src)] if insert_orelse else [] - return LoopIR.If(outer_s.cond, body, orelse, src) + blk_a = outer_s.body - def orelse_wrapper(orelse): - src = outer_s.srcinfo - body = [LoopIR.Pass(src)] - return LoopIR.If(outer_s.cond, body, orelse, src) + ir, fwd = inner_c.body()._wrap(orelse_wrapper, "orelse") + ir, fwd_repl = fwd(inner_c).body()[0].body()._replace(blk_a) + fwd = _compose(fwd_repl, fwd) - if isinstance(inner_s, LoopIR.If): - if inner_s in outer_s.body: - # if INNER: - # if OUTER: if OUTER: A - # if INNER: A else: C - # else: B ~> else: - # else: C if OUTER: B - # else: C + if inner_s.orelse: + ir, fwd_wrap = ( + fwd(inner_c).orelse()._wrap(orelse_wrapper, "orelse") + ) + fwd = _compose(fwd_wrap, fwd) + ir, fwd_repl = fwd(inner_c).orelse()[0].body()._replace(blk_a) + fwd = _compose(fwd_repl, fwd) + elif isinstance(inner_s, LoopIR.For): + # if OUTER: for INNER in _: + # for INNER in _: A ~> if OUTER: A if len(outer_s.body) > 1: raise SchedulingError( f"expected {target_type} to be directly nested in parent" ) - blk_c = outer_s.orelse - wrapper = lambda body: if_wrapper(body, insert_orelse=bool(blk_c)) + if outer_s.orelse: + raise SchedulingError( + "cannot lift for loop when if has an orelse clause" + ) - ir, fwd = inner_c.body()._wrap(wrapper, "body") - if blk_c: - ir, fwd_repl = fwd(inner_c).body()[0].orelse()._replace(blk_c) - fwd = _compose(fwd_repl, fwd) + ir, fwd = inner_c.body()._move(inner_c.after()) + ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_move = fwd(outer_c)._move(fwd(inner_c).body()[0].after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(inner_c).body()[0]._delete() + fwd = _compose(fwd_del, fwd) + + return ir, fwd + + elif isinstance(outer_s, LoopIR.For): + if len(outer_s.body) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) + + def loop_wrapper(body): + return outer_s.update(body=body) + + if isinstance(inner_s, LoopIR.If): + # for OUTER in _: if INNER: + # if INNER: A ~> for OUTER in _: A + # else: B else: + # for OUTER in _: B + if outer_s.iter in _FV(inner_s.cond): + raise SchedulingError("if statement depends on iteration variable") + + ir, fwd = inner_c.body()._wrap(loop_wrapper, "body") if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(wrapper, "body") + ir, fwd_wrap = fwd(inner_c).orelse()._wrap(loop_wrapper, "body") fwd = _compose(fwd_wrap, fwd) + elif isinstance(inner_s, LoopIR.For): + # for OUTER in _: for INNER in _: + # for INNER in _: A ~> for OUTER in _: A + reads = get_reads_of_expr(inner_s.lo) + get_reads_of_expr(inner_s.hi) + if outer_s.iter in [name for name, _ in reads]: + raise SchedulingError( + "inner loop's lo or hi depends on outer loop's iteration variable" + ) + + Check_ReorderLoops(inner_c.get_root(), outer_s) + body = inner_c.body() + ir, fwd = inner_c._move(outer_c.after()) + ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) + fwd = _compose(fwd_move, fwd) + ir, fwd_move = fwd(body)._move(fwd(outer_c).body().after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(outer_c).body()[0]._delete() + fwd = _compose(fwd_del, fwd) + return ir, fwd + + ir, fwd_move = 
fwd(inner_c)._move(fwd(outer_c).after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(outer_c)._delete() + fwd = _compose(fwd_del, fwd) + + return ir, fwd + + def dynamic_check(): + inner_s = inner_c._node + assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) + target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" + + outer_c = inner_c.parent() + if outer_c.root() == outer_c: + raise SchedulingError("Cannot lift scope of top-level statement") + outer_s = outer_c._node + + ir, fwd = inner_c.get_root(), lambda x: x + + if isinstance(outer_s, LoopIR.If): + + def if_wrapper(body, insert_orelse=False): + src = outer_s.srcinfo + # this is needed because _replace expects a non-zero length block + orelse = [LoopIR.Pass(src)] if insert_orelse else [] + return LoopIR.If(outer_s.cond, body, orelse, src) + + def orelse_wrapper(orelse): + src = outer_s.srcinfo + body = [LoopIR.Pass(src)] + return LoopIR.If(outer_s.cond, body, orelse, src) + + if isinstance(inner_s, LoopIR.If): + if inner_s in outer_s.body: + # if INNER: + # if OUTER: if OUTER: A + # if INNER: A else: C + # else: B ~> else: + # else: C if OUTER: B + # else: C + if len(outer_s.body) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) + + blk_c = outer_s.orelse + wrapper = lambda body: if_wrapper(body, insert_orelse=bool(blk_c)) + + ir, fwd = inner_c.body()._wrap(wrapper, "body") if blk_c: - ir, fwd_repl = fwd(inner_c).orelse()[0].orelse()._replace(blk_c) + ir, fwd_repl = fwd(inner_c).body()[0].orelse()._replace(blk_c) fwd = _compose(fwd_repl, fwd) - else: - # if INNER: - # if OUTER: A if OUTER: A - # else: else: B - # if INNER: B ~> else: - # else: C if OUTER: A - # else: C - assert inner_s in outer_s.orelse - if len(outer_s.orelse) > 1: + + if inner_s.orelse: + ir, fwd_wrap = fwd(inner_c).orelse()._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + if blk_c: + ir, fwd_repl = ( + fwd(inner_c).orelse()[0].orelse()._replace(blk_c) + ) + fwd = _compose(fwd_repl, fwd) + else: + # if INNER: + # if OUTER: A if OUTER: A + # else: else: B + # if INNER: B ~> else: + # else: C if OUTER: A + # else: C + assert inner_s in outer_s.orelse + if len(outer_s.orelse) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) + + blk_a = outer_s.body + + ir, fwd = inner_c.body()._wrap(orelse_wrapper, "orelse") + ir, fwd_repl = fwd(inner_c).body()[0].body()._replace(blk_a) + fwd = _compose(fwd_repl, fwd) + + if inner_s.orelse: + ir, fwd_wrap = ( + fwd(inner_c).orelse()._wrap(orelse_wrapper, "orelse") + ) + fwd = _compose(fwd_wrap, fwd) + ir, fwd_repl = fwd(inner_c).orelse()[0].body()._replace(blk_a) + fwd = _compose(fwd_repl, fwd) + elif isinstance(inner_s, LoopIR.For): + # if OUTER: for INNER in _: + # for INNER in _: A ~> if OUTER: A + if len(outer_s.body) > 1: raise SchedulingError( f"expected {target_type} to be directly nested in parent" ) - blk_a = outer_s.body + if outer_s.orelse: + raise SchedulingError( + "cannot lift for loop when if has an orelse clause" + ) - ir, fwd = inner_c.body()._wrap(orelse_wrapper, "orelse") - ir, fwd_repl = fwd(inner_c).body()[0].body()._replace(blk_a) - fwd = _compose(fwd_repl, fwd) + ir, fwd = inner_c.body()._move(inner_c.after()) + ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_move = fwd(outer_c)._move(fwd(inner_c).body()[0].after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(inner_c).body()[0]._delete() + fwd = _compose(fwd_del, fwd) - if 
inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(orelse_wrapper, "orelse") - fwd = _compose(fwd_wrap, fwd) - ir, fwd_repl = fwd(inner_c).orelse()[0].body()._replace(blk_a) - fwd = _compose(fwd_repl, fwd) - elif isinstance(inner_s, LoopIR.For): - # if OUTER: for INNER in _: - # for INNER in _: A ~> if OUTER: A + return ir, fwd + + elif isinstance(outer_s, LoopIR.For): if len(outer_s.body) > 1: raise SchedulingError( f"expected {target_type} to be directly nested in parent" ) - if outer_s.orelse: - raise SchedulingError( - "cannot lift for loop when if has an orelse clause" - ) + def loop_wrapper(body): + return outer_s.update(body=body) - ir, fwd = inner_c.body()._move(inner_c.after()) - ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_move = fwd(outer_c)._move(fwd(inner_c).body()[0].after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(inner_c).body()[0]._delete() - fwd = _compose(fwd_del, fwd) + if isinstance(inner_s, LoopIR.If): + # for OUTER in _: if INNER: + # if INNER: A ~> for OUTER in _: A + # else: B else: + # for OUTER in _: B + if outer_s.iter in _FV(inner_s.cond): + raise SchedulingError("if statement depends on iteration variable") - return ir, fwd + ir, fwd = inner_c.body()._wrap(loop_wrapper, "body") - elif isinstance(outer_s, LoopIR.For): - if len(outer_s.body) > 1: - raise SchedulingError( - f"expected {target_type} to be directly nested in parent" - ) - - def loop_wrapper(body): - return outer_s.update(body=body) - - if isinstance(inner_s, LoopIR.If): - # for OUTER in _: if INNER: - # if INNER: A ~> for OUTER in _: A - # else: B else: - # for OUTER in _: B - if outer_s.iter in _FV(inner_s.cond): - raise SchedulingError("if statement depends on iteration variable") - - ir, fwd = inner_c.body()._wrap(loop_wrapper, "body") + if inner_s.orelse: + ir, fwd_wrap = fwd(inner_c).orelse()._wrap(loop_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + elif isinstance(inner_s, LoopIR.For): + # for OUTER in _: for INNER in _: + # for INNER in _: A ~> for OUTER in _: A + reads = get_reads_of_expr(inner_s.lo) + get_reads_of_expr(inner_s.hi) + if outer_s.iter in [name for name, _ in reads]: + raise SchedulingError( + "inner loop's lo or hi depends on outer loop's iteration variable" + ) - if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(loop_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - elif isinstance(inner_s, LoopIR.For): - # for OUTER in _: for INNER in _: - # for INNER in _: A ~> for OUTER in _: A - reads = get_reads_of_expr(inner_s.lo) + get_reads_of_expr(inner_s.hi) - if outer_s.iter in [name for name, _ in reads]: - raise SchedulingError( - "inner loop's lo or hi depends on outer loop's iteration variable" - ) + body = inner_c.body() + ir, fwd = inner_c._move(outer_c.after()) + ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) + fwd = _compose(fwd_move, fwd) + ir, fwd_move = fwd(body)._move(fwd(outer_c).body().after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(outer_c).body()[0]._delete() + fwd = _compose(fwd_del, fwd) + return ir, fwd - Check_ReorderLoops(inner_c.get_root(), outer_s) - body = inner_c.body() - ir, fwd = inner_c._move(outer_c.after()) - ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) - fwd = _compose(fwd_move, fwd) - ir, fwd_move = fwd(body)._move(fwd(outer_c).body().after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(outer_c).body()[0]._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd + ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) + 
fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(outer_c)._delete() + fwd = _compose(fwd_del, fwd) + fuzz(outer_c.parent(), fwd) - ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(outer_c)._delete() - fwd = _compose(fwd_del, fwd) + return ir, fwd - return ir, fwd + return do_check(static_check, dynamic_check, check_mode) def DoLiftConstant(assign_c, loop_c): @@ -1636,115 +2319,228 @@ def reduces_have_same_constant(s1, s2): return ir, fwd -def DoExpandDim(alloc_cursor, alloc_dim, indexing): +def DoExpandDim(alloc_cursor, alloc_dim, indexing, check_mode: CheckMode): alloc_s = alloc_cursor._node assert isinstance(alloc_s, LoopIR.Alloc) assert isinstance(alloc_dim, LoopIR.expr) assert isinstance(indexing, LoopIR.expr) - Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], alloc_dim) + def static_check(): + Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], alloc_dim) - old_typ = alloc_s.type - new_rngs = [alloc_dim] - if isinstance(old_typ, T.Tensor): - new_rngs += old_typ.shape() - basetyp = old_typ.basetype() - new_typ = T.Tensor(new_rngs, False, basetyp) - new_alloc = alloc_s.update(type=new_typ) + old_typ = alloc_s.type + new_rngs = [alloc_dim] + if isinstance(old_typ, T.Tensor): + new_rngs += old_typ.shape() + basetyp = old_typ.basetype() + new_typ = T.Tensor(new_rngs, False, basetyp) + new_alloc = alloc_s.update(type=new_typ) - ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) + ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) - def mk_read(c): - rd = c._node + def mk_read(c): + rd = c._node - # TODO: do I need to worry about Builtins too? - if isinstance(c.parent()._node, (LoopIR.Call)) and not rd.idx: - raise SchedulingError( - "TODO: Please Contact the developers to fix (i.e. add) " - "support for passing windows to scalar arguments" - ) + # TODO: do I need to worry about Builtins too? + if isinstance(c.parent()._node, (LoopIR.Call)) and not rd.idx: + raise SchedulingError( + "TODO: Please Contact the developers to fix (i.e. add) " + "support for passing windows to scalar arguments" + ) - if isinstance(rd, LoopIR.Read): - return {"idx": [indexing] + rd.idx} - elif isinstance(rd, LoopIR.WindowExpr): - return {"idx": [LoopIR.Point(indexing, rd.srcinfo)] + rd.idx} - else: - raise NotImplementedError( - f"Did not implement {type(rd)}. This may be a bug." - ) + if isinstance(rd, LoopIR.Read): + return {"idx": [indexing] + rd.idx} + elif isinstance(rd, LoopIR.WindowExpr): + return {"idx": [LoopIR.Point(indexing, rd.srcinfo)] + rd.idx} + else: + raise NotImplementedError( + f"Did not implement {type(rd)}. This may be a bug." 
+ ) - def mk_write(c): - s = c._node - return {"idx": [indexing] + s.idx} + def mk_write(c): + s = c._node + return {"idx": [indexing] + s.idx} - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) - after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] - Check_Bounds(ir, new_alloc, after_alloc) + after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] - return ir, fwd + Check_Bounds(ir, new_alloc, after_alloc) + return ir, fwd + + def dynamic_check(): + + old_typ = alloc_s.type + new_rngs = [alloc_dim] + if isinstance(old_typ, T.Tensor): + new_rngs += old_typ.shape() + basetyp = old_typ.basetype() + new_typ = T.Tensor(new_rngs, False, basetyp) + new_alloc = alloc_s.update(type=new_typ) + + ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) + + def mk_read(c): + rd = c._node + + # TODO: do I need to worry about Builtins too? + if isinstance(c.parent()._node, (LoopIR.Call)) and not rd.idx: + raise SchedulingError( + "TODO: Please Contact the developers to fix (i.e. add) " + "support for passing windows to scalar arguments" + ) + + if isinstance(rd, LoopIR.Read): + return {"idx": [indexing] + rd.idx} + elif isinstance(rd, LoopIR.WindowExpr): + return {"idx": [LoopIR.Point(indexing, rd.srcinfo)] + rd.idx} + else: + raise NotImplementedError( + f"Did not implement {type(rd)}. This may be a bug." + ) + + def mk_write(c): + s = c._node + return {"idx": [indexing] + s.idx} + + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] -def DoResizeDim(alloc_cursor, dim_idx: int, size: LoopIR.expr, offset: LoopIR.expr): + fuzz(get_rest_of_block(alloc_cursor), fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) + + +def DoResizeDim( + alloc_cursor, + dim_idx: int, + size: LoopIR.expr, + offset: LoopIR.expr, + check_mode: CheckMode, +): alloc_s = alloc_cursor._node alloc_name = alloc_s.name assert isinstance(alloc_s, LoopIR.Alloc) assert isinstance(alloc_s.type, T.Tensor) - Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], size) + def static_check(): + Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], size) - ir, fwd = ( - alloc_cursor._child_node("type")._child_block("hi")[dim_idx]._replace([size]) - ) + ir, fwd = ( + alloc_cursor._child_node("type") + ._child_block("hi")[dim_idx] + ._replace([size]) + ) - def mk_read(c): - rd = c._node + def mk_read(c): + rd = c._node - def mk_binop(e): - return LoopIR.BinOp("-", e, offset, offset.type, rd.srcinfo) + def mk_binop(e): + return LoopIR.BinOp("-", e, offset, offset.type, rd.srcinfo) - new_idx = rd.idx.copy() - if isinstance(rd, LoopIR.Read): - new_idx[dim_idx] = mk_binop(rd.idx[dim_idx]) - return {"idx": new_idx} + new_idx = rd.idx.copy() + if isinstance(rd, LoopIR.Read): + new_idx[dim_idx] = mk_binop(rd.idx[dim_idx]) + return {"idx": new_idx} - elif isinstance(rd, LoopIR.WindowExpr): - if isinstance(rd.idx[dim_idx], LoopIR.Point): - new_idx[dim_idx] = LoopIR.Point( - mk_binop(rd.idx[dim_idx].pt), rd.srcinfo - ) + elif isinstance(rd, LoopIR.WindowExpr): + if isinstance(rd.idx[dim_idx], LoopIR.Point): + 
new_idx[dim_idx] = LoopIR.Point( + mk_binop(rd.idx[dim_idx].pt), rd.srcinfo + ) + else: + new_idx[dim_idx] = LoopIR.Interval( + mk_binop(rd.idx[dim_idx].lo), + mk_binop(rd.idx[dim_idx].hi), + rd.srcinfo, + ) + + return {"idx": new_idx} else: - new_idx[dim_idx] = LoopIR.Interval( - mk_binop(rd.idx[dim_idx].lo), - mk_binop(rd.idx[dim_idx].hi), - rd.srcinfo, + raise NotImplementedError( + f"Did not implement {type(rd)}. This may be a bug." ) - return {"idx": new_idx} - else: - raise NotImplementedError( - f"Did not implement {type(rd)}. This may be a bug." + def mk_write(c): + s = c._node + new_idx = s.idx.copy() + new_idx[dim_idx] = LoopIR.BinOp( + "-", s.idx[dim_idx], offset, offset.type, s.srcinfo ) + return {"idx": new_idx} - def mk_write(c): - s = c._node - new_idx = s.idx.copy() - new_idx[dim_idx] = LoopIR.BinOp( - "-", s.idx[dim_idx], offset, offset.type, s.srcinfo + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + + new_alloc_cursor = fwd(alloc_cursor) + after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] + + Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + return ir, fwd + + def dynamic_check(): + + ir, fwd = ( + alloc_cursor._child_node("type") + ._child_block("hi")[dim_idx] + ._replace([size]) ) - return {"idx": new_idx} - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + def mk_read(c): + rd = c._node - alloc_cursor = fwd(alloc_cursor) - after_alloc = [c._node for c in get_rest_of_block(alloc_cursor)] - Check_Bounds(ir, alloc_cursor._node, after_alloc) + def mk_binop(e): + return LoopIR.BinOp("-", e, offset, offset.type, rd.srcinfo) - return ir, fwd + new_idx = rd.idx.copy() + if isinstance(rd, LoopIR.Read): + new_idx[dim_idx] = mk_binop(rd.idx[dim_idx]) + return {"idx": new_idx} + + elif isinstance(rd, LoopIR.WindowExpr): + if isinstance(rd.idx[dim_idx], LoopIR.Point): + new_idx[dim_idx] = LoopIR.Point( + mk_binop(rd.idx[dim_idx].pt), rd.srcinfo + ) + else: + new_idx[dim_idx] = LoopIR.Interval( + mk_binop(rd.idx[dim_idx].lo), + mk_binop(rd.idx[dim_idx].hi), + rd.srcinfo, + ) + + return {"idx": new_idx} + else: + raise NotImplementedError( + f"Did not implement {type(rd)}. This may be a bug." 
+ ) + + def mk_write(c): + s = c._node + new_idx = s.idx.copy() + new_idx[dim_idx] = LoopIR.BinOp( + "-", s.idx[dim_idx], offset, offset.type, s.srcinfo + ) + return {"idx": new_idx} + + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + + new_alloc_cursor = fwd(alloc_cursor) + after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] + + fuzz(get_rest_of_block(alloc_cursor), fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) def DoRearrangeDim(decl_cursor, permute_vector): @@ -1826,7 +2622,7 @@ def mk_stride_expr(c): return ir, fwd -def DoDivideDim(alloc_cursor, dim_idx, quotient): +def DoDivideDim(alloc_cursor, dim_idx, quotient, check_mode: CheckMode): alloc_s = alloc_cursor._node alloc_sym = alloc_s.name @@ -1837,52 +2633,102 @@ def DoDivideDim(alloc_cursor, dim_idx, quotient): old_typ = alloc_s.type old_shp = old_typ.shape() dim = old_shp[dim_idx] - Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) - numer = divide_expr(dim, quotient) - new_shp = ( - old_shp[:dim_idx] - + [ - numer, - LoopIR.Const(quotient, T.int, dim.srcinfo), - ] - + old_shp[dim_idx + 1 :] - ) - new_typ = T.Tensor(new_shp, False, old_typ.basetype()) - ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) + def static_check(): + Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) + numer = divide_expr(dim, quotient) + new_shp = ( + old_shp[:dim_idx] + + [ + numer, + LoopIR.Const(quotient, T.int, dim.srcinfo), + ] + + old_shp[dim_idx + 1 :] + ) + new_typ = T.Tensor(new_shp, False, old_typ.basetype()) - def remap_idx(idx): - orig_i = idx[dim_idx] - srcinfo = orig_i.srcinfo - quot = LoopIR.Const(quotient, T.int, srcinfo) - hi = LoopIR.BinOp("/", orig_i, quot, orig_i.type, srcinfo) - lo = LoopIR.BinOp("%", orig_i, quot, orig_i.type, srcinfo) - return idx[:dim_idx] + [hi, lo] + idx[dim_idx + 1 :] + ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) - def mk_read(c): - rd = c._node + def remap_idx(idx): + orig_i = idx[dim_idx] + srcinfo = orig_i.srcinfo + quot = LoopIR.Const(quotient, T.int, srcinfo) + hi = LoopIR.BinOp("/", orig_i, quot, orig_i.type, srcinfo) + lo = LoopIR.BinOp("%", orig_i, quot, orig_i.type, srcinfo) + return idx[:dim_idx] + [hi, lo] + idx[dim_idx + 1 :] - if isinstance(rd, LoopIR.Read) and not rd.idx: - raise SchedulingError( - f"Cannot divide {alloc_sym} because buffer is passed as an argument" - ) - elif isinstance(rd, LoopIR.WindowExpr): - raise SchedulingError( - f"Cannot divide {alloc_sym} because the buffer is windowed later on" - ) + def mk_read(c): + rd = c._node - return {"idx": remap_idx(rd.idx)} + if isinstance(rd, LoopIR.Read) and not rd.idx: + raise SchedulingError( + f"Cannot divide {alloc_sym} because buffer is passed as an argument" + ) + elif isinstance(rd, LoopIR.WindowExpr): + raise SchedulingError( + f"Cannot divide {alloc_sym} because the buffer is windowed later on" + ) - def mk_write(c): - s = c._node - return {"idx": remap_idx(s.idx)} + return {"idx": remap_idx(rd.idx)} - # TODO: add better iteration primitive - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + def mk_write(c): + s = c._node + return {"idx": remap_idx(s.idx)} - return ir, fwd + # TODO: add better iteration primitive + for c in get_rest_of_block(alloc_cursor): + ir, fwd = 
_replace_reads(ir, fwd, c, alloc_s.name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + return ir, fwd + + def dynamic_check(): + numer = divide_expr(dim, quotient) + new_shp = ( + old_shp[:dim_idx] + + [ + numer, + LoopIR.Const(quotient, T.int, dim.srcinfo), + ] + + old_shp[dim_idx + 1 :] + ) + new_typ = T.Tensor(new_shp, False, old_typ.basetype()) + + ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) + + def remap_idx(idx): + orig_i = idx[dim_idx] + srcinfo = orig_i.srcinfo + quot = LoopIR.Const(quotient, T.int, srcinfo) + hi = LoopIR.BinOp("/", orig_i, quot, orig_i.type, srcinfo) + lo = LoopIR.BinOp("%", orig_i, quot, orig_i.type, srcinfo) + return idx[:dim_idx] + [hi, lo] + idx[dim_idx + 1 :] + + def mk_read(c): + rd = c._node + + if isinstance(rd, LoopIR.Read) and not rd.idx: + raise SchedulingError( + f"Cannot divide {alloc_sym} because buffer is passed as an argument" + ) + elif isinstance(rd, LoopIR.WindowExpr): + raise SchedulingError( + f"Cannot divide {alloc_sym} because the buffer is windowed later on" + ) + + return {"idx": remap_idx(rd.idx)} + + def mk_write(c): + s = c._node + return {"idx": remap_idx(s.idx)} + + # TODO: add better iteration primitive + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + fuzz(get_rest_of_block(alloc_cursor), fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) def DoMultiplyDim(alloc_cursor, hi_idx, lo_idx): @@ -2363,146 +3209,296 @@ def _stmt(s): return all(_stmt(s) for s in stmts) -def DoRemoveLoop(loop, unsafe_disable_check): - s = loop._node +def DoRemoveLoop(loop, unsafe_disable_check, check_mode): + def static_check(): + s = loop._node - # Check if we can remove the loop. Conditions are: - # 1. Body does not depend on the loop iteration variable - if s.iter in _FV(s.body): - raise SchedulingError( - f"Cannot remove loop, {s.iter} is not " "free in the loop body." - ) + # Check if we can remove the loop. Conditions are: + # 1. Body does not depend on the loop iteration variable + if s.iter in _FV(s.body): + raise SchedulingError( + f"Cannot remove loop, {s.iter} is not " "free in the loop body." + ) - # 2. Body is idempotent - if not unsafe_disable_check: - Check_IsIdempotent(loop.get_root(), [s]) + # 2. Body is idempotent + if not unsafe_disable_check: + Check_IsIdempotent(loop.get_root(), [s]) - # 3. The loop runs at least once; - # If not, then place a guard around the statement - ir, fwd = loop.get_root(), lambda x: x - try: - Check_IsPositiveExpr(loop.get_root(), [s], s.hi) - except SchedulingError: - cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo) + # 3. 
The loop runs at least once; + # If not, then place a guard around the statement + ir, fwd = loop.get_root(), lambda x: x + try: + Check_IsPositiveExpr(loop.get_root(), [s], s.hi) + except SchedulingError: + cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo) - def wrapper(body): - return LoopIR.If(cond, body, [], s.srcinfo) + def wrapper(body): + return LoopIR.If(cond, body, [], s.srcinfo) - ir, fwd = loop.body()._wrap(wrapper, "body") + ir, fwd = loop.body()._wrap(wrapper, "body") - ir, fwd_move = fwd(loop).body()._move(fwd(loop).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(loop)._delete() - fwd = _compose(fwd_del, fwd) + ir, fwd_move = fwd(loop).body()._move(fwd(loop).after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(loop)._delete() + fwd = _compose(fwd_del, fwd) - return ir, fwd + return ir, fwd + + def dynamic_check(): + s = loop._node + + # Check if we can remove the loop. Conditions are: + # 1. Body does not depend on the loop iteration variable + if s.iter in _FV(s.body): + raise SchedulingError( + f"Cannot remove loop, {s.iter} is not " "free in the loop body." + ) + + # 2. Body is idempotent + + # 3. The loop runs at least once; + # If not, then place a guard around the statement + ir1, fwd1 = loop.get_root(), lambda x: x + ir1, fwd_move1 = fwd1(loop).body()._move(fwd1(loop).after()) + fwd1 = _compose(fwd_move1, fwd1) + ir1, fwd_del1 = fwd1(loop)._delete() + fwd1 = _compose(fwd_del1, fwd1) + try: + fuzz(loop.parent(), fwd1) + return ir1, fwd1 + except SchedulingError: + + def wrapper(body): + return LoopIR.If(cond, body, [], s.srcinfo) + + ir2, fwd2 = loop.body()._wrap(wrapper, "body") + ir2, fwd_move2 = fwd2(loop).body()._move(fwd2(loop).after()) + fwd2 = _compose(fwd_move2, fwd2) + ir2, fwd_del2 = fwd2(loop)._delete() + fwd2 = _compose(fwd_del2, fwd2) + cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo) + fuzz(loop.parent(), fwd2) + return ir2, fwd2 + + return do_check(static_check, dynamic_check, check_mode) # This is same as original FissionAfter, except that # this does not remove loop. We have separate remove_loop # operator for that purpose. -def DoFissionAfterSimple(stmt_cursor, n_lifts, unsafe_disable_checks): - tgt_stmt = stmt_cursor._node - assert isinstance(tgt_stmt, LoopIR.stmt) - assert is_pos_int(n_lifts) - - ir, fwd = stmt_cursor.get_root(), lambda x: x +def DoFissionAfterSimple(stmt_cursor, n_lifts_start, unsafe_disable_checks, check_mode): + def static_check(): + n_lifts = n_lifts_start + tgt_stmt = stmt_cursor._node + assert isinstance(tgt_stmt, LoopIR.stmt) + assert is_pos_int(n_lifts) - def alloc_check(pre, post): - if not _is_alloc_free(pre, post): - pre_allocs = {s.name for s in pre if isinstance(s, LoopIR.Alloc)} - post_FV = _FV(post) - for nm in pre_allocs: - if nm in post_FV: - raise SchedulingError( - f"Will not fission here, because " - f"doing so will hide the allocation " - f"of {nm} from a later use site." - ) + ir, fwd = stmt_cursor.get_root(), lambda x: x + + def alloc_check(pre, post): + if not _is_alloc_free(pre, post): + pre_allocs = {s.name for s in pre if isinstance(s, LoopIR.Alloc)} + post_FV = _FV(post) + for nm in pre_allocs: + if nm in post_FV: + raise SchedulingError( + f"Will not fission here, because " + f"doing so will hide the allocation " + f"of {nm} from a later use site." 
+ ) - cur_c = stmt_cursor - while n_lifts > 0: - n_lifts -= 1 + cur_c = stmt_cursor + while n_lifts > 0: + n_lifts -= 1 - idx = cur_c.get_index() + 1 - par_c = cur_c.parent() - par_s = par_c._node + idx = cur_c.get_index() + 1 + par_c = cur_c.parent() + par_s = par_c._node - if isinstance(par_s, LoopIR.For): - pre_c = par_c.body()[:idx] - post_c = par_c.body()[idx:] - elif isinstance(par_s, LoopIR.If): - if cur_c._node in par_s.body: + if isinstance(par_s, LoopIR.For): pre_c = par_c.body()[:idx] post_c = par_c.body()[idx:] + elif isinstance(par_s, LoopIR.If): + if cur_c._node in par_s.body: + pre_c = par_c.body()[:idx] + post_c = par_c.body()[idx:] + else: + pre_c = par_c.orelse()[:idx] + post_c = par_c.orelse()[idx:] else: - pre_c = par_c.orelse()[:idx] - post_c = par_c.orelse()[idx:] - else: - raise SchedulingError("Can only lift past a for loop or an if statement") + raise SchedulingError( + "Can only lift past a for loop or an if statement" + ) - pre = [s._node for s in pre_c] - post = [s._node for s in post_c] + pre = [s._node for s in pre_c] + post = [s._node for s in post_c] - if not (pre and post): - continue + if not (pre and post): + continue - alloc_check(pre, post) + alloc_check(pre, post) - if isinstance(par_s, LoopIR.For): - # we must check whether the two parts of the - # fission can commute appropriately - no_loop_var_pre = par_s.iter not in _FV(pre) - if not unsafe_disable_checks: - Check_FissionLoop(ir, par_s, pre, post, no_loop_var_pre) + if isinstance(par_s, LoopIR.For): + # we must check whether the two parts of the + # fission can commute appropriately + no_loop_var_pre = par_s.iter not in _FV(pre) + if not unsafe_disable_checks: + Check_FissionLoop(ir, par_s, pre, post, no_loop_var_pre) - # we can skip the loop iteration if the - # body doesn't depend on the loop - # and the body is idempotent + # we can skip the loop iteration if the + # body doesn't depend on the loop + # and the body is idempotent - def wrapper(body): - return par_s.update(body=body) + def wrapper(body): + return par_s.update(body=body) - ir, fwd_wrap = post_c._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) + ir, fwd_wrap = post_c._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) - post_c = fwd_wrap(par_c).body()[-1] - ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) - fwd = _compose(fwd_move, fwd) + post_c = fwd_wrap(par_c).body()[-1] + ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) + fwd = _compose(fwd_move, fwd) - cur_c = fwd_move(fwd_wrap(par_c)) - elif isinstance(par_s, LoopIR.If): - if cur_c._node in par_s.body: + cur_c = fwd_move(fwd_wrap(par_c)) + elif isinstance(par_s, LoopIR.If): + if cur_c._node in par_s.body: - def wrapper(body): - return par_s.update(body=body, orelse=[]) + def wrapper(body): + return par_s.update(body=body, orelse=[]) - ir, fwd_wrap = pre_c._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) + ir, fwd_wrap = pre_c._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) - pre_c = fwd_wrap(par_c).body()[0] - ir, fwd_move = pre_c._move(fwd_wrap(par_c).before()) - fwd = _compose(fwd_move, fwd) + pre_c = fwd_wrap(par_c).body()[0] + ir, fwd_move = pre_c._move(fwd_wrap(par_c).before()) + fwd = _compose(fwd_move, fwd) + + cur_c = fwd_move(fwd_wrap(par_c)).prev() + else: + assert cur_c._node in par_s.orelse + + def wrapper(orelse): + return par_s.update( + body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse + ) + + ir, fwd_wrap = post_c._wrap(wrapper, "orelse") + fwd = _compose(fwd_wrap, fwd) + + post_c = fwd_wrap(par_c).orelse()[-1] + ir, fwd_move = 
post_c._move(fwd_wrap(par_c).after()) + fwd = _compose(fwd_move, fwd) - cur_c = fwd_move(fwd_wrap(par_c)).prev() + cur_c = fwd_move(fwd_wrap(par_c)) + + return ir, fwd + + def dynamic_check(): + n_lifts = n_lifts_start + tgt_stmt = stmt_cursor._node + assert isinstance(tgt_stmt, LoopIR.stmt) + assert is_pos_int(n_lifts) + + ir, fwd = stmt_cursor.get_root(), lambda x: x + + def alloc_check(pre, post): + if not _is_alloc_free(pre, post): + pre_allocs = {s.name for s in pre if isinstance(s, LoopIR.Alloc)} + post_FV = _FV(post) + for nm in pre_allocs: + if nm in post_FV: + raise SchedulingError( + f"Will not fission here, because " + f"doing so will hide the allocation " + f"of {nm} from a later use site." + ) + + cur_c = stmt_cursor + while n_lifts > 0: + n_lifts -= 1 + + idx = cur_c.get_index() + 1 + par_c = cur_c.parent() + par_s = par_c._node + + if isinstance(par_s, LoopIR.For): + pre_c = par_c.body()[:idx] + post_c = par_c.body()[idx:] + elif isinstance(par_s, LoopIR.If): + if cur_c._node in par_s.body: + pre_c = par_c.body()[:idx] + post_c = par_c.body()[idx:] + else: + pre_c = par_c.orelse()[:idx] + post_c = par_c.orelse()[idx:] else: - assert cur_c._node in par_s.orelse + raise SchedulingError( + "Can only lift past a for loop or an if statement" + ) - def wrapper(orelse): - return par_s.update( - body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse - ) + pre = [s._node for s in pre_c] + post = [s._node for s in post_c] - ir, fwd_wrap = post_c._wrap(wrapper, "orelse") + if not (pre and post): + continue + + alloc_check(pre, post) + + if isinstance(par_s, LoopIR.For): + # we must check whether the two parts of the + # fission can commute appropriately + no_loop_var_pre = par_s.iter not in _FV(pre) + + # we can skip the loop iteration if the + # body doesn't depend on the loop + # and the body is idempotent + + def wrapper(body): + return par_s.update(body=body) + + ir, fwd_wrap = post_c._wrap(wrapper, "body") fwd = _compose(fwd_wrap, fwd) - post_c = fwd_wrap(par_c).orelse()[-1] + post_c = fwd_wrap(par_c).body()[-1] ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) fwd = _compose(fwd_move, fwd) cur_c = fwd_move(fwd_wrap(par_c)) + elif isinstance(par_s, LoopIR.If): + if cur_c._node in par_s.body: - return ir, fwd + def wrapper(body): + return par_s.update(body=body, orelse=[]) + + ir, fwd_wrap = pre_c._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + pre_c = fwd_wrap(par_c).body()[0] + ir, fwd_move = pre_c._move(fwd_wrap(par_c).before()) + fwd = _compose(fwd_move, fwd) + + cur_c = fwd_move(fwd_wrap(par_c)).prev() + else: + assert cur_c._node in par_s.orelse + + def wrapper(orelse): + return par_s.update( + body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse + ) + + ir, fwd_wrap = post_c._wrap(wrapper, "orelse") + fwd = _compose(fwd_wrap, fwd) + + post_c = fwd_wrap(par_c).orelse()[-1] + ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) + fwd = _compose(fwd_move, fwd) + + cur_c = fwd_move(fwd_wrap(par_c)) + + fuzz(stmt_cursor, fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) # TODO: Deprecate this with the one above @@ -2688,7 +3684,7 @@ def are_allocs_used_after_block(): return ir, fwd -def DoFuseLoop(f_cursor, s_cursor, unsafe_disable_check=False): +def DoFuseLoop(f_cursor, s_cursor, check_mode: CheckMode, unsafe_disable_check=False): proc = f_cursor.get_root() if f_cursor.next() != s_cursor: @@ -2713,75 +3709,146 @@ def mk_read(e): ir, fwdDel = fwd(s_cursor)._delete() fwd = _compose(fwdDel, fwd) - if not unsafe_disable_check: - x = 
LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) - y = loop2.iter - body1 = loop1.body - body2 = SubstArgs(loop2.body, {y: x}).result() - loop = fwd(f_cursor)._node - Check_FissionLoop(ir, loop, body1, body2) + def static_check(): + if not unsafe_disable_check: + x = LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) + y = loop2.iter + body1 = loop1.body + body2 = SubstArgs(loop2.body, {y: x}).result() + loop = fwd(f_cursor)._node + Check_FissionLoop(ir, loop, body1, body2) + + def dynamic_check(): + fuzz(f_cursor.as_block().expand(delta_lo=0, delta_hi=1), fwd) + + do_check(static_check, dynamic_check, check_mode) return ir, fwd -def DoFuseIf(f_cursor, s_cursor): - proc = f_cursor.get_root() - if f_cursor.next() != s_cursor: - raise SchedulingError( - "expected the two if statements to be fused to come one right after the other" - ) +def DoFuseIf(f_cursor, s_cursor, check_mode: CheckMode): + def static_check(): + proc = f_cursor.get_root() + if f_cursor.next() != s_cursor: + raise SchedulingError( + "expected the two if statements to be fused to come one right after the other" + ) - if1 = f_cursor._node - if2 = s_cursor._node - Check_ExprEqvInContext(proc, if1.cond, [if1], if2.cond, [if2]) + if1 = f_cursor._node + if2 = s_cursor._node + Check_ExprEqvInContext(proc, if1.cond, [if1], if2.cond, [if2]) + + cond = if1.cond + body1 = if1.body + body2 = if2.body + orelse1 = if1.orelse + orelse2 = if2.orelse + ifstmt = LoopIR.If(cond, body1 + body2, orelse1 + orelse2, if1.srcinfo) + + ir, fwd = s_cursor.body()._move(f_cursor.body()[-1].after()) + if f_cursor.orelse(): + ir, fwd_move = ( + fwd(s_cursor).orelse()._move(fwd(f_cursor).orelse()[-1].after()) + ) + fwd = _compose(fwd_move, fwd) + else: + ir, fwd_repl = fwd(f_cursor).orelse()._replace(orelse1 + orelse2) + fwd = _compose(fwd_repl, fwd) + ir, fwd_del = fwd(s_cursor)._delete() + fwd = _compose(fwd_del, fwd) + return ir, fwd - cond = if1.cond - body1 = if1.body - body2 = if2.body - orelse1 = if1.orelse - orelse2 = if2.orelse - ifstmt = LoopIR.If(cond, body1 + body2, orelse1 + orelse2, if1.srcinfo) + def dynamic_check(): + proc = f_cursor.get_root() + if f_cursor.next() != s_cursor: + raise SchedulingError( + "expected the two if statements to be fused to come one right after the other" + ) - ir, fwd = s_cursor.body()._move(f_cursor.body()[-1].after()) - if f_cursor.orelse(): - ir, fwd_move = fwd(s_cursor).orelse()._move(fwd(f_cursor).orelse()[-1].after()) - fwd = _compose(fwd_move, fwd) - else: - ir, fwd_repl = fwd(f_cursor).orelse()._replace(orelse1 + orelse2) - fwd = _compose(fwd_repl, fwd) - ir, fwd_del = fwd(s_cursor)._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd + if1 = f_cursor._node + if2 = s_cursor._node + + cond = if1.cond + body1 = if1.body + body2 = if2.body + orelse1 = if1.orelse + orelse2 = if2.orelse + ifstmt = LoopIR.If(cond, body1 + body2, orelse1 + orelse2, if1.srcinfo) + ir, fwd = s_cursor.body()._move(f_cursor.body()[-1].after()) + if f_cursor.orelse(): + ir, fwd_move = ( + fwd(s_cursor).orelse()._move(fwd(f_cursor).orelse()[-1].after()) + ) + fwd = _compose(fwd_move, fwd) + else: + ir, fwd_repl = fwd(f_cursor).orelse()._replace(orelse1 + orelse2) + fwd = _compose(fwd_repl, fwd) + ir, fwd_del = fwd(s_cursor)._delete() + fwd = _compose(fwd_del, fwd) + fuzz(f_cursor.as_block().expand(delta_lo=0, delta_hi=1), fwd) + return ir, fwd -def DoAddLoop(stmt_cursor, var, hi, guard, unsafe_disable_check): - proc = stmt_cursor.get_root() - s = stmt_cursor._node + return do_check(static_check, dynamic_check, check_mode) - 
if not unsafe_disable_check: - Check_IsIdempotent(proc, [s]) - Check_IsPositiveExpr(proc, [s], hi) - sym = Sym(var) +def DoAddLoop(stmt_cursor, var, hi, guard, unsafe_disable_check, check_mode: CheckMode): + def static_check(): + proc = stmt_cursor.get_root() + s = stmt_cursor._node - def wrapper(body): - if guard: - rdsym = LoopIR.Read(sym, [], T.index, s.srcinfo) - zero = LoopIR.Const(0, T.int, s.srcinfo) - cond = LoopIR.BinOp("==", rdsym, zero, T.bool, s.srcinfo) - body = [LoopIR.If(cond, body, [], s.srcinfo)] + if not unsafe_disable_check: + Check_IsIdempotent(proc, [s]) + Check_IsPositiveExpr(proc, [s], hi) - return LoopIR.For( - sym, - LoopIR.Const(0, T.index, s.srcinfo), - hi, - body, - LoopIR.Seq(), - s.srcinfo, - ) + sym = Sym(var) - ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") - return ir, fwd + def wrapper(body): + if guard: + rdsym = LoopIR.Read(sym, [], T.index, s.srcinfo) + zero = LoopIR.Const(0, T.int, s.srcinfo) + cond = LoopIR.BinOp("==", rdsym, zero, T.bool, s.srcinfo) + body = [LoopIR.If(cond, body, [], s.srcinfo)] + + return LoopIR.For( + sym, + LoopIR.Const(0, T.index, s.srcinfo), + hi, + body, + LoopIR.Seq(), + s.srcinfo, + ) + + ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") + return ir, fwd + + def dynamic_check(): + proc = stmt_cursor.get_root() + s = stmt_cursor._node + + sym = Sym(var) + + def wrapper(body): + if guard: + rdsym = LoopIR.Read(sym, [], T.index, s.srcinfo) + zero = LoopIR.Const(0, T.int, s.srcinfo) + cond = LoopIR.BinOp("==", rdsym, zero, T.bool, s.srcinfo) + body = [LoopIR.If(cond, body, [], s.srcinfo)] + + return LoopIR.For( + sym, + LoopIR.Const(0, T.index, s.srcinfo), + hi, + body, + LoopIR.Seq(), + s.srcinfo, + ) + + ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") + fuzz(stmt_cursor.parent(), fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) # --------------------------------------------------------------------------- # @@ -2843,10 +3910,53 @@ def err_handler(_, msg): return ir, fwd -def DoDeleteConfig(proc_cursor, config_cursor): - eq_mod_config = Check_DeleteConfigWrite(proc_cursor._node, [config_cursor._node]) - p, fwd = config_cursor._delete() - return p, fwd, eq_mod_config +def DoDeleteConfig(proc_cursor, config_cursor, check_mode: CheckMode): + def static_check(): + eq_mod_config = Check_DeleteConfigWrite( + proc_cursor._node, [config_cursor._node] + ) + p, fwd = config_cursor._delete() + return p, fwd, eq_mod_config + + def dynamic_check(): + scope = config_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + p, fwd = config_cursor._delete() + fuzz(scope, fwd) + return p, fwd, None + + return do_check(static_check, dynamic_check, check_mode) + + +def DoDeleteStmt(proc_cursor, stmt_cursor, check_mode: CheckMode): + def static_check(): + assert False, "check must be done with chexo" + + def dynamic_check(): + scope = stmt_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + p, fwd = stmt_cursor._delete() + fuzz(scope, fwd) + return p, fwd + + return do_check(static_check, dynamic_check, check_mode) + + +def DoInsertStmt(proc_cursor, gap_cursor, new_stmt, check_mode: CheckMode): + def static_check(): + assert False, "check must be done with chexo" + + def dynamic_check(): + scope = gap_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + p, fwd = gap_cursor._insert([new_stmt]) + fuzz(scope, fwd) + return p, fwd, None + + return do_check(static_check, dynamic_check, check_mode) def DoDeletePass(proc): 
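[Editor's note] The hunks above repeatedly rewrite scheduling directives into a static_check / dynamic_check pair that is dispatched through do_check(...) and validated with fuzz(...), but neither helper's definition appears in this excerpt. The sketch below is a minimal, hypothetical rendering of that dispatch, inferred only from the call sites; the concrete CheckMode, do_check, and fuzz definitions live in chexo.py / LoopIR_scheduling.py and may differ.

    from typing import Callable, Literal

    # Assumed alias: call sites pass check_mode straight through (and
    # DoEliminateIfDeadBranch passes the literal "static"), so a string
    # literal type is the simplest faithful model.
    CheckMode = Literal["static", "dynamic"]

    def do_check(static_check: Callable, dynamic_check: Callable, check_mode: CheckMode):
        # Each checker performs the rewrite itself and either returns the new
        # (ir, fwd, ...) tuple or raises SchedulingError.  The static path keeps
        # the existing symbolic Check_* machinery; the dynamic path builds the
        # rewritten IR and hands the affected scope to the fuzzer.
        if check_mode == "static":
            return static_check()
        elif check_mode == "dynamic":
            return dynamic_check()
        raise ValueError(f"unknown check mode: {check_mode!r}")

Usage mirrors the directives above, e.g. return do_check(static_check, dynamic_check, check_mode) at the end of DoRemoveLoop, DoFuseIf, or DoAddLoop.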
@@ -3652,57 +4762,95 @@ def map_s(self, sc): raise NotImplementedError(f"bad case {type(s)}") -def DoEliminateIfDeadBranch(if_cursor): - if_stmt = if_cursor._node +def DoEliminateIfDeadBranch(if_cursor, check_mode: CheckMode): + def static_check(): + if_stmt = if_cursor._node - assert isinstance(if_stmt, LoopIR.If) + assert isinstance(if_stmt, LoopIR.If) - ir, fwd = if_cursor.get_root(), lambda x: x + ir, fwd = if_cursor.get_root(), lambda x: x - try: - cond_node = LoopIR.Const(True, T.bool, if_stmt.srcinfo) - Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node) - cond = True - except SchedulingError: try: - cond_node = LoopIR.Const(False, T.bool, if_stmt.srcinfo) + cond_node = LoopIR.Const(True, T.bool, if_stmt.srcinfo) Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node) - cond = False + cond = True except SchedulingError: - raise SchedulingError("If condition isn't always True or always False") + try: + cond_node = LoopIR.Const(False, T.bool, if_stmt.srcinfo) + Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node) + cond = False + except SchedulingError: + raise SchedulingError("If condition isn't always True or always False") - body = if_cursor.body() if cond else if_cursor.orelse() - ir, fwd = body._move(if_cursor.after()) - ir, fwd_del = fwd(if_cursor)._delete() - fwd = _compose(fwd_del, fwd) + body = if_cursor.body() if cond else if_cursor.orelse() + ir, fwd = body._move(if_cursor.after()) + ir, fwd_del = fwd(if_cursor)._delete() + fwd = _compose(fwd_del, fwd) - return ir, fwd + return ir, fwd + + def dynamic_check(): + if_stmt = if_cursor._node + + assert isinstance(if_stmt, LoopIR.If) + body = if_cursor.body() + ir, fwd = body._move(if_cursor.after()) + ir, fwd_del = fwd(if_cursor)._delete() + fwd = _compose(fwd_del, fwd) -def DoEliminateDeadLoop(loop_cursor): - loop_stmt = loop_cursor._node + try: + fuzz(if_cursor.parent(), fwd) + return ir, fwd + except SchedulingError: + body = if_cursor.orelse() + ir, fwd = body._move(if_cursor.after()) + ir, fwd_del = fwd(if_cursor)._delete() + fwd = _compose(fwd_del, fwd) + fuzz(if_cursor.parent(), fwd) + return ir, fwd - assert isinstance(loop_stmt, LoopIR.For) + return do_check(static_check, dynamic_check, "static") - ir, fwd = loop_cursor.get_root(), lambda x: x - try: - Check_CompareExprs(ir, [loop_stmt], loop_stmt.lo, ">=", loop_stmt.hi) - except SchedulingError: - raise SchedulingError("Loop condition isn't always False") +def DoEliminateDeadLoop(loop_cursor, check_mode: CheckMode): + def static_check(): + loop_stmt = loop_cursor._node + + assert isinstance(loop_stmt, LoopIR.For) + + ir, fwd = loop_cursor.get_root(), lambda x: x + + try: + Check_CompareExprs(ir, [loop_stmt], loop_stmt.lo, ">=", loop_stmt.hi) + except SchedulingError: + raise SchedulingError("Loop condition isn't always False") + + ir, fwd_del = loop_cursor._delete() + + return ir, fwd_del - ir, fwd_del = loop_cursor._delete() + def dynamic_check(): + loop_stmt = loop_cursor._node + + assert isinstance(loop_stmt, LoopIR.For) + + ir, fwd_del = loop_cursor._delete() + + fuzz(loop_cursor.parent(), fwd_del) - return ir, fwd_del + return ir, fwd_del + + return do_check(static_check, dynamic_check, check_mode) -def DoEliminateDeadCode(stmt_cursor): +def DoEliminateDeadCode(stmt_cursor, check_mode: CheckMode): stmt = stmt_cursor._node if isinstance(stmt, LoopIR.If): - return DoEliminateIfDeadBranch(stmt_cursor) + return DoEliminateIfDeadBranch(stmt_cursor, check_mode) elif isinstance(stmt, LoopIR.For): - return DoEliminateDeadLoop(stmt_cursor) + 
return DoEliminateDeadLoop(stmt_cursor, check_mode) else: assert False, f"Unsupported statement type {type(stmt)}" @@ -3717,7 +4865,7 @@ def DoDeleteBuffer(buf_cursor): return buf_cursor._delete() -def DoReuseBuffer(buf_cursor, rep_cursor): +def DoReuseBuffer(buf_cursor, rep_cursor, check_mode): assert isinstance(buf_cursor._node, LoopIR.Alloc) assert isinstance(rep_cursor._node, LoopIR.Alloc) assert buf_cursor._node.type == rep_cursor._node.type @@ -3727,23 +4875,43 @@ def DoReuseBuffer(buf_cursor, rep_cursor): rep_name = rep_cursor._node.name first_assn = True - ir, fwd = rep_cursor._delete() + def static_check(): + ir, fwd = rep_cursor._delete() - def mk_read(c): - return {"name": buf_name} + def mk_read(c): + return {"name": buf_name} - def mk_write(c): - nonlocal first_assn - if first_assn: - first_assn = False - Check_IsDeadAfter(buf_cursor.get_root(), [c._node], buf_name, buf_dims) - return {"name": buf_name} + def mk_write(c): + nonlocal first_assn + if first_assn: + first_assn = False + Check_IsDeadAfter(buf_cursor.get_root(), [c._node], buf_name, buf_dims) + return {"name": buf_name} - for c in get_rest_of_block(rep_cursor): - ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write) + for c in get_rest_of_block(rep_cursor): + ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write) + return ir, fwd - return ir, fwd + def dynamic_check(): + ir, fwd = rep_cursor._delete() + + def mk_read(c): + return {"name": buf_name} + + def mk_write(c): + nonlocal first_assn + if first_assn: + first_assn = False + return {"name": buf_name} + + for c in get_rest_of_block(rep_cursor): + ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write) + fuzz(get_rest_of_block(rep_cursor), fwd) + return ir, fwd + + return do_check(static_check, dynamic_check, check_mode) def index_range_analysis_wrapper(expr: LoopIR.expr) -> IndexRange: @@ -3879,62 +5047,117 @@ def do_e(self, e): super().do_e(e) -def DoFoldBuffer(alloc_cursor, dim_idx, new_size): +def DoFoldBuffer(alloc_cursor, dim_idx, new_size, check_mode: CheckMode): alloc_name = alloc_cursor._node.name - buffer_check = CheckFoldBuffer(alloc_name, dim_idx, new_size) - buffer_check.do_stmts([c._node for c in get_rest_of_block(alloc_cursor)]) + def static_check(): + buffer_check = CheckFoldBuffer(alloc_name, dim_idx, new_size) + buffer_check.do_stmts([c._node for c in get_rest_of_block(alloc_cursor)]) + + size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) + ir, fwd = ( + alloc_cursor._child_node("type") + ._child_block("hi")[dim_idx] + ._replace([size_expr]) + ) - size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) - ir, fwd = ( - alloc_cursor._child_node("type") - ._child_block("hi")[dim_idx] - ._replace([size_expr]) - ) + def make_index_mod(e): + return LoopIR.BinOp("%", e, size_expr, T.index, e.srcinfo) - def make_index_mod(e): - return LoopIR.BinOp("%", e, size_expr, T.index, e.srcinfo) + def mk_read(c): + rd = c._node + new_idx = rd.idx.copy() + if isinstance(rd, LoopIR.Read): + new_idx[dim_idx] = make_index_mod(rd.idx[dim_idx]) + return {"idx": new_idx} - def mk_read(c): - rd = c._node - new_idx = rd.idx.copy() - if isinstance(rd, LoopIR.Read): - new_idx[dim_idx] = make_index_mod(rd.idx[dim_idx]) + elif isinstance(rd, LoopIR.WindowExpr): + if isinstance(rd.idx[dim_idx], LoopIR.Point): + new_idx[dim_idx] = LoopIR.Point( + 
make_index_mod(rd.idx[dim_idx].pt), rd.srcinfo + ) + else: + # TODO: see if check_bounds catches the case where lo, hi spans a multiple + # of size, which would break the buffer folding + new_idx[dim_idx] = LoopIR.Interval( + make_index_mod(rd.idx[dim_idx].lo), + make_index_mod(rd.idx[dim_idx].hi), + rd.srcinfo, + ) + + return {"idx": new_idx} + else: + raise NotImplementedError(f"Did not implement {type(rd)}.") + + def mk_write(c): + s = c._node + new_idx = s.idx.copy() + new_idx[dim_idx] = make_index_mod(s.idx[dim_idx]) return {"idx": new_idx} - elif isinstance(rd, LoopIR.WindowExpr): - if isinstance(rd.idx[dim_idx], LoopIR.Point): - new_idx[dim_idx] = LoopIR.Point( - make_index_mod(rd.idx[dim_idx].pt), rd.srcinfo - ) + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + + new_alloc_cursor = fwd(alloc_cursor) + after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] + + Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + return ir, fwd + + def dynamic_check(): + size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) + ir, fwd = ( + alloc_cursor._child_node("type") + ._child_block("hi")[dim_idx] + ._replace([size_expr]) + ) + + def make_index_mod(e): + return LoopIR.BinOp("%", e, size_expr, T.index, e.srcinfo) + + def mk_read(c): + rd = c._node + new_idx = rd.idx.copy() + if isinstance(rd, LoopIR.Read): + new_idx[dim_idx] = make_index_mod(rd.idx[dim_idx]) + return {"idx": new_idx} + + elif isinstance(rd, LoopIR.WindowExpr): + if isinstance(rd.idx[dim_idx], LoopIR.Point): + new_idx[dim_idx] = LoopIR.Point( + make_index_mod(rd.idx[dim_idx].pt), rd.srcinfo + ) + else: + # TODO: see if check_bounds catches the case where lo, hi spans a multiple + # of size, which would break the buffer folding + new_idx[dim_idx] = LoopIR.Interval( + make_index_mod(rd.idx[dim_idx].lo), + make_index_mod(rd.idx[dim_idx].hi), + rd.srcinfo, + ) + + return {"idx": new_idx} else: - # TODO: see if check_bounds catches the case where lo, hi spans a multiple - # of size, which would break the buffer folding - new_idx[dim_idx] = LoopIR.Interval( - make_index_mod(rd.idx[dim_idx].lo), - make_index_mod(rd.idx[dim_idx].hi), - rd.srcinfo, - ) + raise NotImplementedError(f"Did not implement {type(rd)}.") + def mk_write(c): + s = c._node + new_idx = s.idx.copy() + new_idx[dim_idx] = make_index_mod(s.idx[dim_idx]) return {"idx": new_idx} - else: - raise NotImplementedError(f"Did not implement {type(rd)}.") - def mk_write(c): - s = c._node - new_idx = s.idx.copy() - new_idx[dim_idx] = make_index_mod(s.idx[dim_idx]) - return {"idx": new_idx} + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + new_alloc_cursor = fwd(alloc_cursor) + after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] - alloc_cursor = fwd(alloc_cursor) - after_alloc = [c._node for c in get_rest_of_block(alloc_cursor)] - Check_Bounds(ir, alloc_cursor._node, after_alloc) + fuzz(get_rest_of_block(alloc_cursor), fwd) + return ir, fwd - return ir, fwd + return do_check(static_check, dynamic_check, check_mode) def DoStageMem(block_cursor, buf_name, w_exprs, new_name, use_accum_zero=False): diff --git a/src/exo/rewrite/chexo.py 
b/src/exo/rewrite/chexo.py index ceb39c4db..c19fdc088 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -1,4 +1,5 @@ from itertools import chain +import time from typing import Callable, Literal, Optional, Union from ..core.internal_cursors import Cursor, Block, Node, NodePath @@ -78,7 +79,7 @@ def visit(self, node): else: self.type_map[arg.name] = arg.type for stmt in node.f.body: - self.visit_generic(stmt) + self.visit(stmt) else: self.visit_generic(node) @@ -228,8 +229,13 @@ def collect_path_constraints( cursor: Union[Block, Node], cm: ConstraintMaker, type_map: dict[Sym, LoopIR.type] ) -> DisjointConstraint: if isinstance(cursor, Block): - cursor = cursor[0] + if len(cursor) > 0: + cursor = cursor[0] + else: + cursor = cursor._anchor assert isinstance(cursor, Node) + if len(cursor._path) == 0: + return TRUE_CONSTRAINT last_attr, last_index = cursor._path[-1] cur = cursor.parent() result = TRUE_CONSTRAINT @@ -255,7 +261,7 @@ def collect_path_constraints( elif isinstance(cur._node, LoopIR.If): assert last_index is not None modified_variable_visitor = ModifiedVariableVisitor(type_map) - for stmt, _ in zip(cur._node[last_attr], range(last_index)): + for stmt, _ in zip(getattr(cur._node, last_attr), range(last_index)): modified_variable_visitor.visit(stmt) for var_sym in modified_variable_visitor.modified_vars: var_renaming[var_sym] = cm.copy_var(var_sym) @@ -520,6 +526,8 @@ def broaden(self) -> Optional["TestScope"]: return TestScope(self.scope._anchor.as_block()) def transform(self, forward: Callable[[Cursor], Cursor]) -> "TestScope": + if self.broaden() is None: + return TestScope(forward(self.scope._anchor)._child_block("body")) return TestScope(forward(self.scope)) def get_type_map(self) -> dict[Sym, LoopIR.type]: @@ -582,9 +590,90 @@ def get_test_spec( TEST_CASE_BOUND = 15 +MAX_FAILS = 3 +MAX_ITERS = 20 +TIME_RECORDING_FILE = None # "./times.csv" + + +@dataclass +class Timer: + fuzz_start: Optional[int] = None + transpile_start: Optional[int] = None + constraint_start: Optional[int] = None + test_start: Optional[int] = None + scope_widen_count: int = 0 + fuzz_total: int = 0 + transpile_total: int = 0 + constraint_total: int = 0 + test_total: int = 0 + + def start_fuzz(self): + if TIME_RECORDING_FILE is None: + return + self.fuzz_start = time.process_time_ns() + + def widen_scope(self): + self.scope_widen_count += 1 + + def end_fuzz(self): + if TIME_RECORDING_FILE is None: + return + assert self.fuzz_start is not None + self.fuzz_total += time.process_time_ns() - self.fuzz_start + self.fuzz_start = None + + def start_transpile(self): + if TIME_RECORDING_FILE is None: + return + self.transpile_start = time.process_time_ns() + + def end_transpile(self): + if TIME_RECORDING_FILE is None: + return + assert self.transpile_start is not None + self.transpile_total += time.process_time_ns() - self.transpile_start + self.transpile_start = None + + def start_constraint(self): + if TIME_RECORDING_FILE is None: + return + self.constraint_start = time.process_time_ns() + + def end_constraint(self): + if TIME_RECORDING_FILE is None: + return + assert self.constraint_start is not None + self.constraint_total += time.process_time_ns() - self.constraint_start + self.constraint_start = None + + def start_test(self): + if TIME_RECORDING_FILE is None: + return + self.test_start = time.process_time_ns() + + def end_test(self): + if TIME_RECORDING_FILE is None: + return + assert self.test_start is not None + self.test_total += time.process_time_ns() - self.test_start + self.test_start 
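# The Timer added to chexo.py above accumulates per-phase CPU time
# (fuzz / transpile / constraint / test) with time.process_time_ns() and is a
# no-op unless TIME_RECORDING_FILE is set.  A stripped-down sketch of the
# start/end accounting for a single phase; the class name is hypothetical.
import time
from typing import Optional

class PhaseTimer:
    def __init__(self) -> None:
        self._start: Optional[int] = None
        self.total_ns: int = 0

    def start(self) -> None:
        self._start = time.process_time_ns()

    def end(self) -> None:
        assert self._start is not None
        self.total_ns += time.process_time_ns() - self._start
        self._start = None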
= None + + def record(self, failed: bool, unsolved: bool): + if TIME_RECORDING_FILE is None: + return + with open(TIME_RECORDING_FILE, "a") as recording: + recording.write( + f"{self.fuzz_total},{self.transpile_total},{self.constraint_total},{self.test_total},{self.scope_widen_count},{1 if failed else 0},{1 if unsolved else 0}\n" + ) def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): + timer = Timer() + timer.start_fuzz() + starting_scope = Cursor.create(starting_scope.get_root()) + if isinstance(starting_scope, Node) and starting_scope.depth() == 0: + starting_scope = starting_scope.body() + starting_scope = ( starting_scope.as_block() if isinstance(starting_scope, Node) @@ -598,6 +687,7 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): transformed_type_map = cur_scope.transform(fwd).get_type_map() while cur_scope is not None: + timer.start_transpile() transformed = cur_scope.transform(fwd) cm = ConstraintMaker(cur_type_map | transformed_type_map) @@ -626,7 +716,15 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): assert skeleton1 is not None and skeleton2 is not None coverage_skeleton = skeleton1.merge(skeleton2) tests_passed = True - while not coverage_skeleton.get_coverage_progress().is_finished(): + fails = 0 + iters = 0 + timer.end_transpile() + while ( + not coverage_skeleton.get_coverage_progress().is_finished() + and iters < MAX_ITERS + and tests_passed + ): + timer.start_constraint() test_case = generate_test_case( arg_types, config_fields, @@ -634,13 +732,22 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): coverage_skeleton, cm, ) + timer.end_constraint() if test_case is None: - continue + fails += 1 + if fails > MAX_FAILS: + timer.end_fuzz() + timer.record(False, True) + return + else: + continue + timer.start_test() out1 = run_test_case(test_case, transpiled_test1) out2 = run_test_case(test_case, transpiled_test2) if out1 == "failed" or out2 == "failed": tests_passed = False + timer.end_test() break assert out1.coverage_result is not None and out2.coverage_result is not None coverage_skeleton.update_coverage( @@ -659,10 +766,17 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): ): tests_passed = False break + timer.end_test() + iters += 1 if tests_passed: + timer.end_fuzz() + timer.record(False, False) return else: + timer.widen_scope() cur_scope = cur_scope.broaden() + timer.end_fuzz() + timer.record(True, False) raise SchedulingError("tests failed at broadest scope") diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/constraint_solver.py index 444531f68..bc106207f 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/constraint_solver.py @@ -389,21 +389,32 @@ def rename_syms(self, lookup: dict[Sym, Sym]) -> "ConstraintClause": ) +MAX_CLAUSES = 16 + + @dataclass class DisjointConstraint: clauses: tuple[ConstraintClause, ...] 
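# The reworked fuzz() loop above bounds the search: at most MAX_ITERS test
# cases per scope, at most MAX_FAILS unsolvable constraint queries (after
# which the rewrite is accepted as "unsolved"), and when a generated test
# exposes a divergence the scope is broadened and the search restarts; running
# out of scopes raises a scheduling error.  A plain-Python analogue of that
# control flow -- the callbacks and names are hypothetical stand-ins:
from typing import Callable, Optional

MAX_FAILS = 3
MAX_ITERS = 20

def bounded_fuzz(make_case: Callable[[], Optional[dict]],
                 outputs_agree: Callable[[dict], bool],
                 broaden: Callable[[], bool]) -> None:
    while True:
        fails = 0
        diverged = False
        for _ in range(MAX_ITERS):
            case = make_case()
            if case is None:             # solver could not produce a case
                fails += 1
                if fails > MAX_FAILS:
                    return               # give up solving, accept the rewrite
                continue
            if not outputs_agree(case):  # original vs. rewritten proc differ
                diverged = True
                break
        if not diverged:
            return                       # every generated case agreed
        if not broaden():                # no wider scope left to try
            raise RuntimeError("tests failed at broadest scope")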
def intersect(self, other: "DisjointConstraint"): - return DisjointConstraint( - tuple( - ConstraintClause(lhs_clause.constraints + rhs_clause.constraints) - for lhs_clause in self.clauses - for rhs_clause in other.clauses - ) + new_clauses = tuple( + ConstraintClause(lhs_clause.constraints + rhs_clause.constraints) + for lhs_clause in self.clauses + for rhs_clause in other.clauses ) + if len(new_clauses) > MAX_CLAUSES: + new_clauses = tuple( + np.random.choice(new_clauses, MAX_CLAUSES, replace=False) + ) + return DisjointConstraint(new_clauses) def union(self, other: "DisjointConstraint"): - return DisjointConstraint(self.clauses + other.clauses) + new_clauses = self.clauses + other.clauses + if len(new_clauses) > MAX_CLAUSES: + new_clauses = tuple( + np.random.choice(new_clauses, MAX_CLAUSES, replace=False) + ) + return DisjointConstraint(new_clauses) def invert(self) -> "DisjointConstraint": acc = TRUE_CONSTRAINT @@ -461,6 +472,9 @@ class Solution: substitutions: dict[Sym, int] +SIMUL_CONSTRAINT_LIMIT = 32 + + class ConstraintMaker: def __init__(self, type_map: dict[Sym, LoopIR.type]): self.var_subs: dict[Sym, Expression] = {} @@ -500,7 +514,7 @@ def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression] return Expression.from_sym(Sym(f"{name}_m1")).add( Expression.from_constant(1) ) - elif isinstance(var_type, (T.Int, T.Index)): + elif isinstance(var_type, (T.Int, T.Index, T.INT32, T.INT8, T.UINT16, T.UINT8)): # unsigned variables are represented as a - b, where a and b are nonnegative a, b = Sym(f"{name}_a"), Sym(f"{name}_b") return Expression.from_sym(a).add(Expression.from_sym(b).negate()) @@ -927,6 +941,7 @@ def solve_constraint( clause for clause in disjoint_constraint.clauses if all(not constraint.is_unsolvable() for constraint in clause.constraints) + and len(clause.constraints) <= SIMUL_CONSTRAINT_LIMIT ) for _ in range(search_limit): if len(clauses) == 0: diff --git a/src/exo/stdlib/scheduling.py b/src/exo/stdlib/scheduling.py index 9e30eb177..26b867e78 100644 --- a/src/exo/stdlib/scheduling.py +++ b/src/exo/stdlib/scheduling.py @@ -29,6 +29,8 @@ bind_expr, commute_expr, left_reassociate_expr, + insert_mutate, + delete_stmt, # # subprocedure oriented operations extract_subproc, @@ -65,6 +67,7 @@ parallelize_loop, divide_with_recompute, divide_loop, + divide_loop_min, mult_loops, cut_loop, join_loops, From 1e8b5be7e38ac1ed79bfe7570565ceae7985b512 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Mon, 26 May 2025 15:48:02 -0400 Subject: [PATCH 19/24] clean up --- src/exo/API.py | 2 +- src/exo/API_scheduling.py | 28 - src/exo/backend/LoopIR_transpiler.py | 100 ++- src/exo/backend/coverage.py | 174 ++-- src/exo/frontend/typecheck.py | 5 +- src/exo/rewrite/LoopIR_scheduling.py | 1149 +++++++------------------- src/exo/rewrite/chexo.py | 196 ++--- 7 files changed, 555 insertions(+), 1099 deletions(-) diff --git a/src/exo/API.py b/src/exo/API.py index ac5b19503..557d4f042 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -192,7 +192,7 @@ def __init__( proc = TypeChecker(proc, self._check_mode).get_loopir() if self._check_mode != "dynamic": CheckBounds(proc) - Check_Aliasing(proc) + Check_Aliasing(proc) assert isinstance(proc, LoopIR.LoopIR.proc) diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index 5db2d8845..e71d7e031 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -720,34 +720,6 @@ def __call__(self, expr_str, all_args): return expr -class NewStmtA(ArgumentProcessor): - def __init__(self, cursor_arg, before=True): 
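# DisjointConstraint.intersect/union above cap the clause count: intersecting
# two disjunctions takes the cross product of their clause lists, so the
# result is randomly subsampled down to MAX_CLAUSES to keep solving tractable.
# The patch uses np.random.choice; the same idea with the standard library and
# a stand-in clause type:
import random

MAX_CLAUSES = 16

def cap_clauses(clauses: tuple) -> tuple:
    # Keep at most MAX_CLAUSES clauses, chosen uniformly without replacement.
    if len(clauses) <= MAX_CLAUSES:
        return clauses
    return tuple(random.sample(clauses, MAX_CLAUSES))

# Dropping disjuncts under-approximates the constraint: any assignment that
# satisfies a surviving clause still satisfies the original disjunction, so
# generated test cases remain valid; only coverage of the dropped clauses is lost.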
- self.cursor_arg = cursor_arg - self.before = before - - def _get_ctxt_stmt(self, all_args): - cursor = all_args[self.cursor_arg] - while isinstance(cursor, PC.ExprCursor): - cursor = cursor.parent() - - # if we don't have a gap cursor, convert to a gap cursor - if not isinstance(cursor, PC.GapCursor): - cursor = cursor.before() if self.before else cursor.after() - - # TODO: improve parse_fragment to just take gaps - return cursor.anchor()._impl._node - - def __call__(self, stmt_str, all_args): - if not isinstance(stmt_str, str): - self.err("expected a string") - - proc = all_args["proc"] - ctxt_stmt = self._get_ctxt_stmt(all_args) - - expr = parse_fragment(proc._loopir_proc, stmt_str, ctxt_stmt) - return expr - - # This is implemented as a workaround because the # current PAST parser and PAST IR don't support windowing # expressions. diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/backend/LoopIR_transpiler.py index 5544becb7..3d9c427a4 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/backend/LoopIR_transpiler.py @@ -17,6 +17,7 @@ MemoryAccessPair, ParallelAccess, ParallelAccessPair, + StagingBoundCheck, SymbolicPoint, SymbolicSlice, StagingOverlap, @@ -260,32 +261,104 @@ class SymbolicWindow: index: tuple[SymbolicWindowIndex, ...] +@dataclass +class StagedWindowExpr: + indices: tuple[Union[tuple[LoopIR.expr, LoopIR.expr], LoopIR.expr], ...] + + @dataclass class StageMemArgs: - window_expr: LoopIR.WindowExpr + buffer_sym: Sym + staged_window_expr: StagedWindowExpr scope: Block class StageMemTracker: def __init__(self, args: StageMemArgs, parent_state: "CoverageState"): self.scope: Block = args.scope - self.buffer_sym = args.window_expr.name + self.buffer_sym = args.buffer_sym + self.staged_window_expr = args.staged_window_expr self.staged_window: Optional[tuple[SymbolicTensor, Tensor]] = None + self.enabled: bool = False self.overlaps: list[StagingOverlap] = [] + self.bound_checks: list[StagingBoundCheck] = [] self.parent_state: "CoverageState" = parent_state def enter_stmt(self, stmt_node: Node): if stmt_node in self.scope: - js_tensor = self.parent_state.parent_transpiler._lookup_sym(self.buffer_sym) - assert isinstance(js_tensor, Tensor) - self.staged_window = ( - self.parent_state.symbolic_tensors[self.buffer_sym], - js_tensor, - ) + self.enabled = True + + if self.staged_window is None: + staged_win_sym = Sym("win") + js_parent = self.parent_state.parent_transpiler._lookup_sym( + self.buffer_sym + ) + js_staged = self.parent_state.parent_transpiler._transpile_window( + staged_win_sym, + self.buffer_sym, + Cursor.create(self.staged_window_expr)._child_block("indices"), + ) + assert isinstance(js_parent, Tensor) + symbolic_parent = self.parent_state.symbolic_tensors[self.buffer_sym] + symbolic_staged = self.parent_state.symbolic_tensors[staged_win_sym] + stage_placeholder = ( + self.parent_state.parent_transpiler._make_placeholder() + ) + for ( + symbolic_parent_dim, + symbolic_staged_dim, + js_parent_dim, + js_staged_dim, + ) in zip( + symbolic_parent.dims, + symbolic_staged.dims, + (dim.window_idx for dim in js_parent.dims), + (dim.window_idx for dim in js_staged.dims), + ): + if isinstance(symbolic_parent_dim, SymbolicSlice): + assert isinstance(js_parent_dim, Slice) + out_of_bounds_sym = Sym("oob") + js_out_of_bounds_cond = ( + "&&".join( + ( + f"({js_staged_dim.index}<{js_parent_dim.upper_bound})", + f"({js_parent_dim.lower_bound}<={js_staged_dim.index})", + ) + ) + if isinstance(js_staged_dim, Point) + else "&&".join( + ( + 
f"({js_parent_dim.lower_bound}<={js_staged_dim.lower_bound})", + f"({js_staged_dim.upper_bound})<{js_parent_dim.upper_bound})", + ) + ) + ) + self.bound_checks.append( + StagingBoundCheck( + out_of_bounds_sym, + symbolic_staged_dim, + symbolic_parent_dim, + self.parent_state.current_node, + ( + IndexedFiller( + self.parent_state.cov_placeholder, + f"let {repr(out_of_bounds_sym)}=false;", + ), + IndexedFiller( + stage_placeholder, + f"if({js_out_of_bounds_cond}){{let {repr(out_of_bounds_sym)}=true;}}", + ), + ), + ) + ) + self.staged_window = ( + symbolic_staged, + js_staged, + ) def exit_stmt(self, stmt_cursor: Node): if stmt_cursor in self.scope: - self.staged_window = None + self.enabled = False def access_tensor( self, @@ -299,6 +372,7 @@ def access_tensor( ): if ( self.staged_window is not None + and self.enabled and self.staged_window[1].name == js_tensor.name ): symbolic_staged_window, js_staged_window = self.staged_window @@ -340,6 +414,9 @@ def access_tensor( def make_staging_overlaps(self) -> tuple[StagingOverlap, ...]: return tuple(self.overlaps) + def make_staging_bound_checks(self) -> tuple[StagingBoundCheck, ...]: + return tuple(self.bound_checks) + @dataclass class ParallelScope: @@ -861,6 +938,11 @@ def make_skeleton(self) -> CoverageSkeleton: if self.stage_mem_tracker is None else self.stage_mem_tracker.make_staging_overlaps() ), + ( + () + if self.stage_mem_tracker is None + else self.stage_mem_tracker.make_staging_bound_checks() + ), self.parallel_access_tracker.make_parallel_access_pairs(), frozenset(self.free_vars), ) diff --git a/src/exo/backend/coverage.py b/src/exo/backend/coverage.py index 5ab8f035f..350154655 100644 --- a/src/exo/backend/coverage.py +++ b/src/exo/backend/coverage.py @@ -404,6 +404,92 @@ def rename_syms(self, lookup: dict[Sym, Sym]) -> "SymbolicSlice": SymbolicWindowIndex = Union[SymbolicPoint, SymbolicSlice] +@dataclass +class StagingBoundCheck: + out_of_bounds_sym: Sym + staged_index: SymbolicWindowIndex + parent_index: SymbolicSlice + node: CoverageSkeletonNode + indexed_fillers: tuple[IndexedFiller, ...] 
+ had_out_of_bounds: bool = False + visited_out_of_bounds: bool = False + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.out_of_bounds_sym,)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + out_of_bounds = coverage_result[repr(self.out_of_bounds_sym)] + assert isinstance(out_of_bounds, bool) + self.had_out_of_bounds |= out_of_bounds + self.visited_out_of_bounds |= out_of_bounds + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + (1 if self.visited_out_of_bounds else 0), + 1, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + if not self.visited_out_of_bounds: + out_of_bounds_cond = ( + Constraint( + self.staged_index.index.add(Expression.from_constant(1)) + .negate() + .add(self.parent_index.upper_bound), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + self.staged_index.index.add( + self.parent_index.lower_bound.negate() + ), + True, + ).lift_to_disjoint_constraint() + ) + if isinstance(self.staged_index, SymbolicPoint) + else Constraint( + self.staged_index.upper_bound.negate().add( + self.parent_index.upper_bound + ), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + self.staged_index.lower_bound.add( + self.parent_index.lower_bound.negate() + ), + True, + ).lift_to_disjoint_constraint() + ) + ) + path_constraint = self.node.get_complete_constraint().intersect( + out_of_bounds_cond + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + self.visited_out_of_bounds = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + # for stage_mem @dataclass class StagingOverlap: @@ -668,12 +754,23 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: return state +CoverageTask = Union[ + CoverageSkeletonNode, + MemoryAccessPair, + FailureCondition, + StagingOverlap, + StagingBoundCheck, + ParallelAccessPair, +] + + @dataclass class CoverageSkeleton: roots: tuple[CoverageSkeletonNode, ...] aliasable_accesses: tuple[MemoryAccessPair, ...] failure_conditions: tuple[FailureCondition, ...] staging_overlaps: tuple[StagingOverlap, ...] + staging_bound_checks: tuple[StagingBoundCheck, ...] parallel_accesses: tuple[ParallelAccessPair, ...] 
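# StagingBoundCheck above reports progress as a (covered, total) pair -- 1/1
# once its case has been exercised, 0/1 before -- and the fuzz loop keeps
# generating tests until every tracked task is finished.  A minimal sketch of
# such a progress pair; the patch's CoverageProgress is richer, and the names
# below are illustrative.
from dataclasses import dataclass

@dataclass
class Progress:
    covered: int
    total: int

    def merge(self, other: "Progress") -> "Progress":
        # Fold progress across heterogeneous coverage tasks.
        return Progress(self.covered + other.covered, self.total + other.total)

    def is_finished(self) -> bool:
        return self.covered >= self.total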
free_vars: frozenset[Sym] @@ -683,67 +780,38 @@ def merge(self, other: "CoverageSkeleton") -> "CoverageSkeleton": self.aliasable_accesses + other.aliasable_accesses, self.failure_conditions + other.failure_conditions, self.staging_overlaps + other.staging_overlaps, + self.staging_bound_checks + other.staging_bound_checks, self.parallel_accesses + other.parallel_accesses, self.free_vars | other.free_vars, ) + def get_coverage_tasks(self) -> tuple[CoverageTask, ...]: + return ( + self.parallel_accesses + + self.staging_bound_checks + + self.staging_overlaps + + self.failure_conditions + + self.aliasable_accesses + + self.roots + ) + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: - for root in self.roots: - yield from root.get_indexed_fillers() - for aliasable_access in self.aliasable_accesses: - yield from aliasable_access.get_indexed_fillers() - for failure_condition in self.failure_conditions: - yield from failure_condition.get_indexed_fillers() - for staging_overlap in self.staging_overlaps: - yield from staging_overlap.get_indexed_fillers() - for parallel_access in self.parallel_accesses: - yield from parallel_access.get_indexed_fillers() + for task in self.get_coverage_tasks(): + yield from task.get_indexed_fillers() def get_coverage_syms(self) -> frozenset[Sym]: return frozenset().union( - *tuple(root_node.get_coverage_syms() for root_node in self.roots), - *tuple( - aliasable_access.get_coverage_syms() - for aliasable_access in self.aliasable_accesses - ), - *tuple( - failure_condition.get_coverage_syms() - for failure_condition in self.failure_conditions - ), - *tuple( - staging_overlap.get_coverage_syms() - for staging_overlap in self.staging_overlaps - ), - *tuple( - parallel_access.get_coverage_syms() - for parallel_access in self.parallel_accesses - ), + *tuple(task.get_coverage_syms() for task in self.get_coverage_tasks()), ) def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): - for root_node in self.roots: - root_node.update_coverage(coverage_result) - for aliasable_access in self.aliasable_accesses: - aliasable_access.update_coverage(coverage_result) - for failure_condition in self.failure_conditions: - failure_condition.update_coverage(coverage_result) - for staging_overlap in self.staging_overlaps: - staging_overlap.update_coverage(coverage_result) - for parallel_access in self.parallel_accesses: - parallel_access.update_coverage(coverage_result) + for task in reversed(self.get_coverage_tasks()): + task.update_coverage(coverage_result) def get_coverage_progress(self) -> CoverageProgress: result = CoverageProgress(0, 0) - for root_node in self.roots: - result = root_node.get_coverage_progress() - for aliasable_access in self.aliasable_accesses: - result = result.merge(aliasable_access.get_coverage_progress()) - for failure_condition in self.failure_conditions: - result = result.merge(failure_condition.get_coverage_progress()) - for staging_overlap in self.staging_overlaps: - result = result.merge(staging_overlap.get_coverage_progress()) - for parallel_access in self.parallel_accesses: - result = result.merge(parallel_access.get_coverage_progress()) + for task in self.get_coverage_tasks(): + result = task.get_coverage_progress() return result def solve_constraint_with_coverage( @@ -768,14 +836,6 @@ def solve_constraint_with_coverage( bound, search_limit, ) - for parallel_access in self.parallel_accesses: - state = parallel_access.solve_coverage(state) - for staging_overlap in self.staging_overlaps: - state = 
staging_overlap.solve_coverage(state) - for failure_condition in self.failure_conditions: - state = failure_condition.solve_coverage(state) - for aliasable_access in self.aliasable_accesses: - state = aliasable_access.solve_coverage(state) - for root_node in self.roots: - state = root_node.solve_coverage(state) + for task in self.get_coverage_tasks(): + state = task.solve_coverage(state) return state.current_solution diff --git a/src/exo/frontend/typecheck.py b/src/exo/frontend/typecheck.py index 1518a88cc..9bd237694 100644 --- a/src/exo/frontend/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Union from ..core.LoopIR import ( T, UAST, @@ -97,7 +97,8 @@ def check_call_types(err_handler, args, call_args): assert False, "bad argument type case" -CheckMode = Literal["static", "dynamic", "both"] +Checker = Literal["static", "dynamic"] +CheckMode = Union[Checker, Literal["both"]] class TypeChecker: diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index f5b5650f0..22a4e1561 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -1,7 +1,17 @@ import re from collections import ChainMap import traceback -from typing import Any, Callable, Generator, List, Literal, Tuple, Optional, Union +from typing import ( + Any, + Callable, + Generator, + List, + Literal, + Tuple, + Optional, + TypeVar, + Union, +) from ..core.LoopIR import ( LoopIR, @@ -33,7 +43,7 @@ Check_ExprBound, Check_Aliasing, ) -from .chexo import fuzz, fuzz_reorder_stmts +from .chexo import fuzz from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis from ..core.internal_cursors import Block, Node @@ -44,7 +54,7 @@ import exo.API as api from ..frontend.pattern_match import match_pattern from ..core.memory import DRAM -from ..frontend.typecheck import check_call_types, CheckMode +from ..frontend.typecheck import Checker, check_call_types, CheckMode from ..libs.externs import intmin from functools import partial @@ -424,23 +434,24 @@ def divide_expr(e, quot): # --------------------------------------------------------------------------- # # Scheduling directives +T = TypeVar("T") + def do_check( - static_check: Callable[[], Any], - dynamic_check: Callable[[], Any], + check: Callable[[Checker], T], mode: CheckMode, -) -> Any: +) -> T: if mode == "both": e_static, e_dynamic = None, None trb_static, trb_dynamic = None, None static_res = None try: - static_res = static_check() + static_res = check("static") except Exception as e: e_static = e trb_static = traceback.format_exc() try: - dynamic_check() + check("dynamic") except Exception as e: e_dynamic = e trb_dynamic = traceback.format_exc() @@ -452,44 +463,46 @@ def do_check( raise e_static else: return static_res - elif mode == "static": - return static_check() - elif mode == "dynamic": - return dynamic_check() + else: + return check(mode) # Take a conservative approach and allow stmt reordering only when they are # writing to different buffers # TODO: Do effectcheck's check_commutes-ish thing using SMT here def DoReorderStmt(f_cursor, s_cursor, check_mode: CheckMode): - if f_cursor.next() != s_cursor: - raise SchedulingError( - "expected the second statement to be directly after the first" - ) - do_check( - lambda: Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node), - lambda: fuzz_reorder_stmts(f_cursor, s_cursor), + def check(checker: Checker): + if f_cursor.next() != s_cursor: + raise 
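# do_check above collapses the earlier static_check/dynamic_check pair into a
# single `check(checker)` callable: "static" and "dynamic" simply forward the
# tag, while "both" runs the two paths for cross-validation (the patch also
# compares the exceptions and tracebacks of the two runs, omitted here).  A
# trimmed sketch reusing the Checker/CheckMode aliases from typecheck.py:
from typing import Callable, Literal, TypeVar, Union

T = TypeVar("T")
Checker = Literal["static", "dynamic"]
CheckMode = Union[Checker, Literal["both"]]

def do_check_sketch(check: Callable[[Checker], T], mode: CheckMode) -> T:
    if mode != "both":
        return check(mode)
    result = check("static")   # the static result is what gets returned
    check("dynamic")           # the dynamic run is only for cross-checking
    return result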
SchedulingError( + "expected the second statement to be directly after the first" + ) + ir, fwd = s_cursor._move(f_cursor.before()) + if checker == "dynamic": + fuzz( + f_cursor.as_block().expand(0, 1), fwd(s_cursor).as_block().expand(0, 1) + ) + else: + Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node) + return ir, fwd + + return do_check( + check, check_mode, ) - ir, fwd = s_cursor._move(f_cursor.before()) - return ir, fwd def DoParallelizeLoop(loop_cursor, check_mode: CheckMode): - ir, fwd = loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) - - def static_check(): - pass - - def dynamic_check(): - fuzz(loop_cursor, fwd) + def check(checker: Checker): + ir, fwd = loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) + if checker == "dynamic": + fuzz(loop_cursor.as_block(), fwd(loop_cursor).as_block()) + return ir, fwd - do_check(static_check, dynamic_check, check_mode) - return ir, fwd + return do_check(check, check_mode) def DoJoinLoops(loop1_c, loop2_c, check_mode: CheckMode): - def static_check(): + def check(checker: Checker): if loop1_c.next() != loop2_c: raise SchedulingError( "expected the second loop to be directly after the first" @@ -498,14 +511,15 @@ def static_check(): loop1 = loop1_c._node loop2 = loop2_c._node - try: - Check_ExprEqvInContext( - loop1_c.get_root(), loop1.hi, [loop1], loop2.lo, [loop2] - ) - except Exception as e: - raise SchedulingError( - f"expected the first loop upper bound {loop1.hi} to be the same as the second loop lower bound {loop2.lo}" - ) + if checker == "static": + try: + Check_ExprEqvInContext( + loop1_c.get_root(), loop1.hi, [loop1], loop2.lo, [loop2] + ) + except Exception as e: + raise SchedulingError( + f"expected the first loop upper bound {loop1.hi} to be the same as the second loop lower bound {loop2.lo}" + ) compare_ir = LoopIR_Compare() if not compare_ir.match_stmts(loop1.body, loop2.body): @@ -514,87 +528,61 @@ def static_check(): ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) ir, fwd_del = fwd(loop2_c)._delete() - return ir, _compose(fwd_del, fwd) - - def dynamic_check(): - if loop1_c.next() != loop2_c: - raise SchedulingError( - "expected the second loop to be directly after the first" - ) - - loop1 = loop1_c._node - loop2 = loop2_c._node - - compare_ir = LoopIR_Compare() - if not compare_ir.match_stmts(loop1.body, loop2.body): - raise SchedulingError("expected the two loops to have identical bodies") - - ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) - ir, fwd_del = fwd(loop2_c)._delete() - fuzz(loop1_c.as_block().expand(delta_lo=0, delta_hi=1), fwd) + if checker == "dynamic": + fuzz(loop1_c.as_block().expand(0, 1), fwd_del(fwd(loop1_c)).as_block()) return ir, _compose(fwd_del, fwd) - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoCutLoop(loop_c, cut_point, check_mode: CheckMode): - def static_check(): + def check(checker: Checker): s = loop_c._node assert isinstance(s, LoopIR.For) ir = loop_c.get_root() - try: - Check_CompareExprs(ir, [s], cut_point, ">=", s.lo) - except SchedulingError: - raise SchedulingError(f"Expected `lo` <= `cut_point`") + if checker == "static": + try: + Check_CompareExprs(ir, [s], cut_point, ">=", s.lo) + except SchedulingError: + raise SchedulingError(f"Expected `lo` <= `cut_point`") - try: - Check_CompareExprs(ir, [s], s.hi, ">=", cut_point) - except SchedulingError: - raise SchedulingError(f"Expected `cut_point` <= `hi`") + try: + Check_CompareExprs(ir, [s], s.hi, ">=", cut_point) + except 
SchedulingError: + raise SchedulingError(f"Expected `cut_point` <= `hi`") ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) fwd = _compose(fwd2, fwd1) - return ir, fwd - - def dynamic_check(): - s = loop_c._node - - assert isinstance(s, LoopIR.For) - - ir = loop_c.get_root() - - ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) - loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] - ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) - fwd = _compose(fwd2, fwd1) - fuzz(loop_c.parent(), fwd) + if checker == "dynamic": + fuzz(loop_c.as_block(), fwd(loop_c).as_block().expand(0, 1)) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoShiftLoop(loop_c, new_lo, check_mode: CheckMode): - def static_check(): + def check(checker: Checker): s = loop_c._node assert isinstance(s, LoopIR.For) - try: - Check_IsNonNegativeExpr( - loop_c.get_root(), - [s], - new_lo, - ) - except SchedulingError: - raise SchedulingError(f"Expected 0 <= `new_lo`") + if checker == "static": + try: + Check_IsNonNegativeExpr( + loop_c.get_root(), + [s], + new_lo, + ) + except SchedulingError: + raise SchedulingError(f"Expected 0 <= `new_lo`") loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) @@ -618,39 +606,12 @@ def static_check(): lambda _: new_iter, only_replace_attrs=False, ) - return ir, fwd - - def dynamic_check(): - s = loop_c._node - - assert isinstance(s, LoopIR.For) - - loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) - new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) - - ir, fwd1 = loop_c._child_node("lo")._replace(new_lo) - ir, fwd2 = fwd1(loop_c)._child_node("hi")._replace(new_hi) - fwd12 = _compose(fwd2, fwd1) - - # all uses of the loop iteration in the second body need - # to be offset by (`lo` - `new_lo``) - loop_iter = s.iter - iter_node = LoopIR.Read(loop_iter, [], T.index, s.srcinfo) - iter_offset = LoopIR.BinOp("-", s.lo, new_lo, T.index, s.srcinfo) - new_iter = LoopIR.BinOp("+", iter_node, iter_offset, T.index, s.srcinfo) - ir, fwd = _replace_reads( - ir, - fwd12, - loop_c, - loop_iter, - lambda _: new_iter, - only_replace_attrs=False, - ) - fuzz(loop_c, fwd) + if checker == "dynamic": + fuzz(loop_c.as_block(), fwd(loop_c).as_block()) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoProductLoop(outer_loop_c, new_name): @@ -848,85 +809,17 @@ def DoDivideWithRecompute( iter_i: str, check_mode: CheckMode, ): - proc = loop_cursor.get_root() - loop = loop_cursor._node - srcinfo = loop.srcinfo - - assert isinstance(loop, LoopIR.For) - assert isinstance(outer_hi, LoopIR.expr) - - def static_check(): - Check_IsIdempotent(proc, loop.body) - - def rd(i): - return LoopIR.Read(i, [], T.index, srcinfo) - - def cnst(intval): - return LoopIR.Const(intval, T.int, srcinfo) - - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - - sym_o = Sym(iter_o) - sym_i = Sym(iter_i) - x = cnst(outer_stride) - - if ( - isinstance(outer_hi, LoopIR.BinOp) - and outer_hi.op == "/" - and isinstance(outer_hi.rhs, LoopIR.Const) - and outer_hi.rhs.val == outer_stride - ): - N_before_recompute = szop("-", outer_hi.lhs, szop("%", outer_hi.lhs, x)) - else: - N_before_recompute = szop("*", outer_hi, x) - - N_recompute = LoopIR.BinOp("-", loop.hi, 
N_before_recompute, T.index, srcinfo) - try: - Check_IsNonNegativeExpr(proc, [loop], N_recompute) - except SchedulingError: - raise SchedulingError( - f"outer_hi * outer_stride exceeds loop's hi {loop.hi}" - ) - - hi_o = outer_hi - hi_i = szop("+", x, N_recompute) - - # turn current loop into outer loop - ir, fwd = loop_cursor._child_node("iter")._replace(sym_o) - ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(hi_o) - fwd = _compose(fwd_repl, fwd) - - # wrap body in inner loop - def inner_wrapper(body): - return LoopIR.For( - sym_i, - LoopIR.Const(0, T.index, srcinfo), - hi_i, - body, - LoopIR.Seq(), - srcinfo, - ) - - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - # replace the iteration variable in the body - def mk_iter(_): - return szop("+", szop("*", rd(sym_o), x), rd(sym_i)) + def check(checker: Checker): + proc = loop_cursor.get_root() + loop = loop_cursor._node + srcinfo = loop.srcinfo - ir, fwd = _replace_reads( - ir, - fwd, - loop_cursor, - loop.iter, - mk_iter, - only_replace_attrs=False, - ) + assert isinstance(loop, LoopIR.For) + assert isinstance(outer_hi, LoopIR.expr) - return ir, fwd + if checker == "static": + Check_IsIdempotent(proc, loop.body) - def dynamic_check(): def rd(i): return LoopIR.Read(i, [], T.index, srcinfo) @@ -951,6 +844,13 @@ def szop(op, lhs, rhs): N_before_recompute = szop("*", outer_hi, x) N_recompute = LoopIR.BinOp("-", loop.hi, N_before_recompute, T.index, srcinfo) + if checker == "static": + try: + Check_IsNonNegativeExpr(proc, [loop], N_recompute) + except SchedulingError: + raise SchedulingError( + f"outer_hi * outer_stride exceeds loop's hi {loop.hi}" + ) hi_o = outer_hi hi_i = szop("+", x, N_recompute) @@ -986,11 +886,12 @@ def mk_iter(_): mk_iter, only_replace_attrs=False, ) - fuzz(loop_cursor.parent(), fwd) + if checker == "dynamic": + fuzz(loop_cursor.as_block(), fwd(loop_cursor).as_block()) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoDivideLoop( @@ -1002,7 +903,7 @@ def DoDivideLoop( tail="guard", perfect=False, ): - def static_check(): + def check(checker: Checker): loop = loop_cursor._node N = loop.hi outer_i = Sym(outer_iter) @@ -1047,9 +948,10 @@ def ceildiv(lhs, rhs): elif tail_strategy in ["cut", "cut_and_guard"]: outer_hi = szop("/", N, inner_hi) # floor div elif tail_strategy == "perfect": - ir = loop_cursor.get_root() - loop = loop_cursor._node - Check_IsDivisible(ir, [loop], N, quot) + if checker == "static": + ir = loop_cursor.get_root() + loop = loop_cursor._node + Check_IsDivisible(ir, [loop], N, quot) outer_hi = divide_expr(N, quot) else: assert False, f"bad tail strategy: {tail_strategy}" @@ -1124,173 +1026,56 @@ def mk_main_iter(c): ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) fwd = _compose(fwd_ins, fwd) + if checker == "dynamic": + fuzz( + loop_cursor.as_block(), + ( + fwd(loop_cursor).as_block().expand(0, 1) + if tail_strategy in ["cut", "cut_and_guard"] + else fwd(loop_cursor).as_block() + ), + ) return ir, fwd - def dynamic_check(): - loop = loop_cursor._node - N = loop.hi - outer_i = Sym(outer_iter) - inner_i = Sym(inner_iter) - srcinfo = loop.srcinfo - tail_strategy = "perfect" if perfect else tail - - if not is_const_zero(loop.lo): - raise SchedulingError( - f"expected the lower bound of the loop to be zero, got {loop.lo}." 
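# DoDivideLoop above rewrites `for i in range(N)` into an outer/inner nest
# with i == quot*outer + inner; the "guard" tail strategy keeps a bounds guard
# in the body, while the "cut" strategies append a remainder loop.  A
# plain-Python illustration of the two shapes (helper name is hypothetical,
# q > 0 assumed):
def divided_iteration(N: int, q: int, tail: str = "guard") -> list:
    visited = []
    if tail == "guard":
        for outer in range((N + q - 1) // q):      # ceil(N / q) outer trips
            for inner in range(q):
                i = q * outer + inner
                if i < N:                          # guard against overrun
                    visited.append(i)
    else:  # "cut": floor(N / q) full tiles plus a remainder loop
        for outer in range(N // q):
            for inner in range(q):
                visited.append(q * outer + inner)
        for inner in range(N % q):
            visited.append((N // q) * q + inner)
    return visited

# Both strategies visit exactly the original iteration space, in order:
assert divided_iteration(10, 4, "guard") == list(range(10))
assert divided_iteration(10, 4, "cut") == list(range(10))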
- ) + return do_check(check, check_mode) - def substitute(srcinfo): - cnst = lambda x: LoopIR.Const(x, T.int, srcinfo) - rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) - op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - return op("+", op("*", cnst(quot), rd(outer_i)), rd(inner_i)) +def DoDivideLoopMin( + loop_cursor, + quot, + outer_iter, + inner_iter, + check_mode: CheckMode, +): + if check_mode != "dynamic": + raise SchedulingError("cannot use min tail strategy without chexo") + loop = loop_cursor._node + N = loop.hi + outer_i = Sym(outer_iter) + inner_i = Sym(inner_iter) + srcinfo = loop.srcinfo - # short-hands for sanity - def boolop(op, lhs, rhs, typ): - return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) + if not is_const_zero(loop.lo): + raise SchedulingError( + f"expected the lower bound of the loop to be zero, got {loop.lo}." + ) - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + def substitute(srcinfo): + rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) + op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - def cnst(intval): - return LoopIR.Const(intval, T.int, srcinfo) + return op("+", op("*", quot, rd(outer_i)), rd(inner_i)) - def rd(i): - return LoopIR.Read(i, [], T.index, srcinfo) + # short-hands for sanity + def boolop(op, lhs, rhs, typ): + return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) - def ceildiv(lhs, rhs): - assert isinstance(rhs, LoopIR.Const) and rhs.val > 0 - rhs_1 = cnst(rhs.val - 1) - return szop("/", szop("+", lhs, rhs_1), rhs) + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) - # determine hi and lo loop bounds - inner_hi = cnst(quot) - if tail_strategy in ["guard"]: - outer_hi = ceildiv(N, inner_hi) - elif tail_strategy in ["cut", "cut_and_guard"]: - outer_hi = szop("/", N, inner_hi) # floor div - elif tail_strategy == "perfect": - ir = loop_cursor.get_root() - loop = loop_cursor._node - outer_hi = divide_expr(N, quot) - else: - assert False, f"bad tail strategy: {tail_strategy}" - - # turn current loop into outer loop - ir, fwd = loop_cursor._child_node("iter")._replace(outer_i) - ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(outer_hi) - fwd = _compose(fwd_repl, fwd) - - # wrap body in a guard - if tail_strategy == "guard": - idx_sub = substitute(srcinfo) - - def guard_wrapper(body): - cond = boolop("<", idx_sub, N, T.bool) - return LoopIR.If(cond, body, [], srcinfo) - - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(guard_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - # wrap body in inner loop - def inner_wrapper(body): - return LoopIR.For( - inner_i, - LoopIR.Const(0, T.index, srcinfo), - inner_hi, - body, - loop.loop_mode, - srcinfo, - ) - - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - # replace the iteration variable in the body - def mk_main_iter(c): - return substitute(c._node.srcinfo) - - ir, fwd = _replace_reads( - ir, - fwd, - loop_cursor, - loop.iter, - mk_main_iter, - only_replace_attrs=False, - ) - - # add the tail case - if tail_strategy in ["cut", "cut_and_guard"]: - cut_i = Sym(inner_iter) - Ntail = szop("%", N, inner_hi) - - # in the tail loop we want the iteration variable to - # be mapped instead to (Ncut*Q + cut_i) - cut_tail_sub = szop("+", rd(cut_i), szop("*", outer_hi, inner_hi)) - - cut_body = Alpha_Rename(loop.body).result() - env = {loop.iter: cut_tail_sub} - cut_body = SubstArgs(cut_body, env).result() - - cut_s = LoopIR.For( - cut_i, - LoopIR.Const(0, T.index, 
srcinfo), - Ntail, - cut_body, - loop.loop_mode, - srcinfo, - ) - if tail_strategy == "cut_and_guard": - cond = boolop(">", Ntail, LoopIR.Const(0, T.int, srcinfo), T.bool) - cut_s = LoopIR.If(cond, [cut_s], [], srcinfo) - - ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) - fwd = _compose(fwd_ins, fwd) - - if tail_strategy == "perfect": - fuzz(loop_cursor.parent(), fwd) - return ir, fwd - - return do_check(static_check, dynamic_check, check_mode) - - -def DoDivideLoopMin( - loop_cursor, - quot, - outer_iter, - inner_iter, - check_mode: CheckMode, -): - if check_mode != "dynamic": - raise SchedulingError("cannot use min tail strategy without chexo") - loop = loop_cursor._node - N = loop.hi - outer_i = Sym(outer_iter) - inner_i = Sym(inner_iter) - srcinfo = loop.srcinfo - - if not is_const_zero(loop.lo): - raise SchedulingError( - f"expected the lower bound of the loop to be zero, got {loop.lo}." - ) - - def substitute(srcinfo): - rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) - op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - - return op("+", op("*", quot, rd(outer_i)), rd(inner_i)) - - # short-hands for sanity - def boolop(op, lhs, rhs, typ): - return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) - - def cnst(intval): - return LoopIR.Const(intval, T.int, srcinfo) - - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) def rd(i): return LoopIR.Read(i, [], T.index, srcinfo) @@ -1354,7 +1139,7 @@ def mk_main_iter(c): only_replace_attrs=False, ) - fuzz(loop_cursor.parent(), fwd) + fuzz(loop_cursor.as_block(), fwd(loop_cursor).as_block()) return ir, fwd @@ -1637,7 +1422,7 @@ def mk_write(c): def DoConfigWrite( stmt_cursor, config, field, expr, check_mode: CheckMode, before=False ): - def static_check(): + def check(checker: Checker): assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) s = stmt_cursor._node @@ -1648,27 +1433,17 @@ def static_check(): else: ir, fwd = stmt_cursor.after()._insert([cw_s]) - cfg = Check_DeleteConfigWrite(ir, [cw_s]) - return ir, fwd, cfg - - def dynamic_check(): - assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) - s = stmt_cursor._node - - cw_s = LoopIR.WriteConfig(config, field, expr, s.srcinfo) - - if before: - ir, fwd1 = stmt_cursor.before()._insert([LoopIR.Pass(s.srcinfo)]) - pass_cursor = fwd1(stmt_cursor).prev() + if checker == "static": + cfg = Check_DeleteConfigWrite(ir, [cw_s]) else: - ir, fwd1 = stmt_cursor.after()._insert([LoopIR.Pass(s.srcinfo)]) - pass_cursor = fwd1(stmt_cursor).next() - ir, fwd2 = pass_cursor._replace([cw_s]) - - fuzz(pass_cursor, fwd2) - return ir, _compose(fwd2, fwd1), None + cfg = None + fuzz( + stmt_cursor.as_block(), + fwd(stmt_cursor).as_block().expand(*((1, 0) if before else (0, 1))), + ) + return ir, fwd, cfg - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) # --------------------------------------------------------------------------- # @@ -1677,7 +1452,7 @@ def dynamic_check(): def DoBindConfig(config, field, expr_cursor, check_mode): - def static_check(): + def check(checker: Checker): e = expr_cursor._node assert isinstance(e, LoopIR.Read) @@ -1688,7 +1463,10 @@ def static_check(): cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) ir, fwd = c.before()._insert([cfg_write_s]) - mod_cfg = Check_DeleteConfigWrite(ir, [cfg_write_s]) + if checker == "static": + mod_cfg = 
Check_DeleteConfigWrite(ir, [cfg_write_s]) + else: + mod_cfg = None cfg_f_type = config.lookup_type(field) cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) @@ -1697,32 +1475,13 @@ def static_check(): ir, fwd_repl = fwd(expr_cursor)._replace(cfg_read_e) fwd = _compose(fwd_repl, fwd) - Check_Aliasing(ir) + if checker == "static": + Check_Aliasing(ir) + else: + fuzz(c.as_block(), fwd(c).as_block().expand(1, 0)) return ir, fwd, mod_cfg - def dynamic_check(): - e = expr_cursor._node - - c = expr_cursor - while not isinstance(c._node, LoopIR.stmt): - c = c.parent() - - cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) - ir, fwd1 = c.before()._insert([LoopIR.Pass(e.srcinfo)]) - pass_cursor = fwd1(c).prev() - ir, fwd2 = pass_cursor._replace([cfg_write_s]) - new_expr_cursor = fwd2(fwd1(expr_cursor)) - - cfg_f_type = config.lookup_type(field) - cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) - if isinstance(expr_cursor.parent()._node, LoopIR.Call): - cfg_read_e = [cfg_read_e] - ir, fwd3 = new_expr_cursor._replace(cfg_read_e) - fwd = _compose(fwd3, _compose(fwd2, fwd1)) - fuzz(pass_cursor.as_block().expand(delta_hi=1), _compose(fwd3, fwd2)) - return ir, fwd, None - - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoCommuteExpr(expr_cursors): @@ -1773,80 +1532,29 @@ def match_parent(c1, c2): def DoRewriteExpr(expr_cursor, new_expr, check_mode): - ir, fwd = expr_cursor._replace(new_expr) - - def static_check(): - proc = expr_cursor.get_root() - s = get_enclosing_stmt_cursor(expr_cursor)._node - Check_ExprEqvInContext(proc, expr_cursor._node, [s], new_expr, [s]) - - def dynamic_check(): - fuzz(get_enclosing_stmt_cursor(expr_cursor), fwd) + def check(checker: Checker): + ir, fwd = expr_cursor._replace(new_expr) + if checker == "static": + proc = expr_cursor.get_root() + s = get_enclosing_stmt_cursor(expr_cursor)._node + Check_ExprEqvInContext(proc, expr_cursor._node, [s], new_expr, [s]) + else: + stmt_cursor = get_enclosing_stmt_cursor(expr_cursor) + fuzz(stmt_cursor.as_block(), fwd(stmt_cursor).as_block()) + return ir, fwd - do_check(static_check, dynamic_check, check_mode) - return ir, fwd + return do_check(check, check_mode) def DoBindExpr(new_name, expr_cursors, check_mode: CheckMode): - def static_check(): - assert expr_cursors - - expr = expr_cursors[0]._node - assert isinstance(expr, LoopIR.expr) - assert expr.type.is_numeric() - - expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] - # TODO: dirty hack. need real CSE-equality (i.e. 
modulo srcinfo) - expr_cursors_eq = [c for c in expr_cursors if str(c._node) == str(expr)] - - init_s = get_enclosing_stmt_cursor(expr_cursors_eq[0]) - if len(expr_cursors_eq) > 1: - # TODO: Currently assume expr cursors is sorted in order - init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) - - new_name_sym = Sym(new_name) - alloc_s = LoopIR.Alloc(new_name_sym, expr.type.basetype(), DRAM, expr.srcinfo) - assign_s = LoopIR.Assign( - new_name_sym, expr.type.basetype(), [], expr, expr.srcinfo - ) - ir, fwd = init_s.before()._insert([alloc_s, assign_s]) - - new_read = LoopIR.Read(new_name_sym, [], expr.type, expr.srcinfo) - first_write_c = None - for c in get_rest_of_block(init_s, inclusive=True): - for block in match_pattern(c, "_ = _") + match_pattern(c, "_ += _"): - assert len(block) == 1 - sc = block[0] - if sc._node.name in expr_reads: - first_write_c = sc - break - - if first_write_c and isinstance(c._node, (LoopIR.For, LoopIR.If)): - # Potentially unsafe to partially bind, err on side of caution for now - break - - while expr_cursors_eq and c.is_ancestor_of(expr_cursors_eq[0]): - ir, fwd_repl = _replace_helper( - fwd(expr_cursors_eq[0]), new_read, only_replace_attrs=False - ) - fwd = _compose(fwd_repl, fwd) - expr_cursors_eq.pop(0) - - if first_write_c: - break - - if len(expr_cursors_eq) > 0: - raise SchedulingError("Unsafe to bind all of the provided exprs.") - - Check_Aliasing(ir) - return ir, fwd - - def dynamic_check(): + def check(checker: Checker): assert expr_cursors expr = expr_cursors[0]._node assert isinstance(expr, LoopIR.expr) et = expr.type if expr.type.is_numeric() else T.i32 + if checker == "static": + assert expr.type.is_numeric() expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] # TODO: dirty hack. need real CSE-equality (i.e. 
modulo srcinfo) @@ -1855,187 +1563,57 @@ def dynamic_check(): init_s = get_enclosing_stmt_cursor(expr_cursors_eq[0]) if len(expr_cursors_eq) > 1: # TODO: Currently assume expr cursors is sorted in order - init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) - - new_name_sym = Sym(new_name) - alloc_s = LoopIR.Alloc(new_name_sym, et, DRAM, expr.srcinfo) - assign_s = LoopIR.Assign(new_name_sym, et, [], expr, expr.srcinfo) - ir, fwd1 = init_s.before()._insert([LoopIR.Pass(expr.srcinfo)]) - pass_cursor = fwd1(init_s).prev() - ir, fwd2 = pass_cursor._replace([alloc_s, assign_s]) - - new_read = LoopIR.Read(new_name_sym, [], et, expr.srcinfo) - for c in get_rest_of_block(init_s, inclusive=True): - while expr_cursors_eq and c.is_ancestor_of(expr_cursors_eq[0]): - ir, fwd_repl = _replace_helper( - fwd2(fwd1(expr_cursors_eq[0])), new_read, only_replace_attrs=False - ) - fwd2 = _compose(fwd_repl, fwd2) - expr_cursors_eq.pop(0) - - if len(expr_cursors_eq) > 0: - raise SchedulingError("Unsafe to bind all of the provided exprs.") - - fuzz(get_rest_of_block(pass_cursor, inclusive=True), fwd2) - return ir, _compose(fwd2, fwd1) - - return do_check(static_check, dynamic_check, check_mode) - - -def DoLiftScope(inner_c, check_mode: CheckMode): - def static_check(): - inner_s = inner_c._node - assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) - target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" - - outer_c = inner_c.parent() - if outer_c.root() == outer_c: - raise SchedulingError("Cannot lift scope of top-level statement") - outer_s = outer_c._node - - ir, fwd = inner_c.get_root(), lambda x: x - - if isinstance(outer_s, LoopIR.If): - - def if_wrapper(body, insert_orelse=False): - src = outer_s.srcinfo - # this is needed because _replace expects a non-zero length block - orelse = [LoopIR.Pass(src)] if insert_orelse else [] - return LoopIR.If(outer_s.cond, body, orelse, src) - - def orelse_wrapper(orelse): - src = outer_s.srcinfo - body = [LoopIR.Pass(src)] - return LoopIR.If(outer_s.cond, body, orelse, src) - - if isinstance(inner_s, LoopIR.If): - if inner_s in outer_s.body: - # if INNER: - # if OUTER: if OUTER: A - # if INNER: A else: C - # else: B ~> else: - # else: C if OUTER: B - # else: C - if len(outer_s.body) > 1: - raise SchedulingError( - f"expected {target_type} to be directly nested in parent" - ) - - blk_c = outer_s.orelse - wrapper = lambda body: if_wrapper(body, insert_orelse=bool(blk_c)) - - ir, fwd = inner_c.body()._wrap(wrapper, "body") - if blk_c: - ir, fwd_repl = fwd(inner_c).body()[0].orelse()._replace(blk_c) - fwd = _compose(fwd_repl, fwd) - - if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - if blk_c: - ir, fwd_repl = ( - fwd(inner_c).orelse()[0].orelse()._replace(blk_c) - ) - fwd = _compose(fwd_repl, fwd) - else: - # if INNER: - # if OUTER: A if OUTER: A - # else: else: B - # if INNER: B ~> else: - # else: C if OUTER: A - # else: C - assert inner_s in outer_s.orelse - if len(outer_s.orelse) > 1: - raise SchedulingError( - f"expected {target_type} to be directly nested in parent" - ) - - blk_a = outer_s.body - - ir, fwd = inner_c.body()._wrap(orelse_wrapper, "orelse") - ir, fwd_repl = fwd(inner_c).body()[0].body()._replace(blk_a) - fwd = _compose(fwd_repl, fwd) - - if inner_s.orelse: - ir, fwd_wrap = ( - fwd(inner_c).orelse()._wrap(orelse_wrapper, "orelse") - ) - fwd = _compose(fwd_wrap, fwd) - ir, fwd_repl = fwd(inner_c).orelse()[0].body()._replace(blk_a) - fwd = _compose(fwd_repl, fwd) - elif 
isinstance(inner_s, LoopIR.For): - # if OUTER: for INNER in _: - # for INNER in _: A ~> if OUTER: A - if len(outer_s.body) > 1: - raise SchedulingError( - f"expected {target_type} to be directly nested in parent" - ) - - if outer_s.orelse: - raise SchedulingError( - "cannot lift for loop when if has an orelse clause" - ) - - ir, fwd = inner_c.body()._move(inner_c.after()) - ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_move = fwd(outer_c)._move(fwd(inner_c).body()[0].after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(inner_c).body()[0]._delete() - fwd = _compose(fwd_del, fwd) - - return ir, fwd - - elif isinstance(outer_s, LoopIR.For): - if len(outer_s.body) > 1: - raise SchedulingError( - f"expected {target_type} to be directly nested in parent" - ) - - def loop_wrapper(body): - return outer_s.update(body=body) + init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) - if isinstance(inner_s, LoopIR.If): - # for OUTER in _: if INNER: - # if INNER: A ~> for OUTER in _: A - # else: B else: - # for OUTER in _: B - if outer_s.iter in _FV(inner_s.cond): - raise SchedulingError("if statement depends on iteration variable") + new_name_sym = Sym(new_name) + alloc_s = LoopIR.Alloc(new_name_sym, expr.type.basetype(), DRAM, expr.srcinfo) + assign_s = LoopIR.Assign( + new_name_sym, expr.type.basetype(), [], expr, expr.srcinfo + ) + ir, fwd = init_s.before()._insert([alloc_s, assign_s]) - ir, fwd = inner_c.body()._wrap(loop_wrapper, "body") + new_read = LoopIR.Read(new_name_sym, [], expr.type, expr.srcinfo) + first_write_c = None + for c in get_rest_of_block(init_s, inclusive=True): + if checker == "static": + for block in match_pattern(c, "_ = _") + match_pattern(c, "_ += _"): + assert len(block) == 1 + sc = block[0] + if sc._node.name in expr_reads: + first_write_c = sc + break + + if first_write_c and isinstance(c._node, (LoopIR.For, LoopIR.If)): + # Potentially unsafe to partially bind, err on side of caution for now + break - if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(loop_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - elif isinstance(inner_s, LoopIR.For): - # for OUTER in _: for INNER in _: - # for INNER in _: A ~> for OUTER in _: A - reads = get_reads_of_expr(inner_s.lo) + get_reads_of_expr(inner_s.hi) - if outer_s.iter in [name for name, _ in reads]: - raise SchedulingError( - "inner loop's lo or hi depends on outer loop's iteration variable" - ) + while expr_cursors_eq and c.is_ancestor_of(expr_cursors_eq[0]): + ir, fwd_repl = _replace_helper( + fwd(expr_cursors_eq[0]), new_read, only_replace_attrs=False + ) + fwd = _compose(fwd_repl, fwd) + expr_cursors_eq.pop(0) - Check_ReorderLoops(inner_c.get_root(), outer_s) - body = inner_c.body() - ir, fwd = inner_c._move(outer_c.after()) - ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) - fwd = _compose(fwd_move, fwd) - ir, fwd_move = fwd(body)._move(fwd(outer_c).body().after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(outer_c).body()[0]._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd + if first_write_c: + break - ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(outer_c)._delete() - fwd = _compose(fwd_del, fwd) + if len(expr_cursors_eq) > 0: + raise SchedulingError("Unsafe to bind all of the provided exprs.") + if checker == "static": + Check_Aliasing(ir) + else: + fuzz( + init_s.as_block().expand(0, None), + fwd(init_s).as_block().expand(1, None), + ) return ir, fwd - def 
dynamic_check(): + return do_check(check, check_mode) + + +def DoLiftScope(inner_c, check_mode: CheckMode): + def check(checker: Checker): inner_s = inner_c._node assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" @@ -2136,6 +1714,8 @@ def orelse_wrapper(orelse): ir, fwd_del = fwd(inner_c).body()[0]._delete() fwd = _compose(fwd_del, fwd) + if checker == "dynamic": + fuzz(outer_c.as_block(), fwd(inner_c).as_block()) return ir, fwd elif isinstance(outer_s, LoopIR.For): @@ -2169,6 +1749,8 @@ def loop_wrapper(body): "inner loop's lo or hi depends on outer loop's iteration variable" ) + if checker == "static": + Check_ReorderLoops(inner_c.get_root(), outer_s) body = inner_c.body() ir, fwd = inner_c._move(outer_c.after()) ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) @@ -2177,17 +1759,21 @@ def loop_wrapper(body): fwd = _compose(fwd_move, fwd) ir, fwd_del = fwd(outer_c).body()[0]._delete() fwd = _compose(fwd_del, fwd) + if checker == "dynamic": + fuzz(outer_c.as_block(), fwd(inner_c).as_block()) return ir, fwd ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) fwd = _compose(fwd_move, fwd) ir, fwd_del = fwd(outer_c)._delete() fwd = _compose(fwd_del, fwd) - fuzz(outer_c.parent(), fwd) + + if checker == "dynamic": + fuzz(outer_c.as_block(), fwd(inner_c).as_block()) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoLiftConstant(assign_c, loop_c): @@ -2320,57 +1906,14 @@ def reduces_have_same_constant(s1, s2): def DoExpandDim(alloc_cursor, alloc_dim, indexing, check_mode: CheckMode): - alloc_s = alloc_cursor._node - assert isinstance(alloc_s, LoopIR.Alloc) - assert isinstance(alloc_dim, LoopIR.expr) - assert isinstance(indexing, LoopIR.expr) - - def static_check(): - Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], alloc_dim) - - old_typ = alloc_s.type - new_rngs = [alloc_dim] - if isinstance(old_typ, T.Tensor): - new_rngs += old_typ.shape() - basetyp = old_typ.basetype() - new_typ = T.Tensor(new_rngs, False, basetyp) - new_alloc = alloc_s.update(type=new_typ) - - ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) - - def mk_read(c): - rd = c._node - - # TODO: do I need to worry about Builtins too? - if isinstance(c.parent()._node, (LoopIR.Call)) and not rd.idx: - raise SchedulingError( - "TODO: Please Contact the developers to fix (i.e. add) " - "support for passing windows to scalar arguments" - ) - - if isinstance(rd, LoopIR.Read): - return {"idx": [indexing] + rd.idx} - elif isinstance(rd, LoopIR.WindowExpr): - return {"idx": [LoopIR.Point(indexing, rd.srcinfo)] + rd.idx} - else: - raise NotImplementedError( - f"Did not implement {type(rd)}. This may be a bug." 
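# The for/if case of DoLiftScope above swaps `for i: if cond: A` into
# `if cond: for i: A`, and rejects the rewrite when `cond` mentions the
# iteration variable.  A plain-Python illustration of the equivalence that the
# side condition protects (all names here are illustrative):
def nested_then_guarded(n: int, flag: bool) -> list:
    out = []
    for i in range(n):
        if flag:          # guard independent of i: safe to lift
            out.append(i)
    return out

def guarded_then_nested(n: int, flag: bool) -> list:
    out = []
    if flag:
        for i in range(n):
            out.append(i)
    return out

# Equivalent precisely because the guard does not depend on i:
assert nested_then_guarded(5, True) == guarded_then_nested(5, True)
assert nested_then_guarded(5, False) == guarded_then_nested(5, False)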
- ) - - def mk_write(c): - s = c._node - return {"idx": [indexing] + s.idx} - - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) - - after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] + def check(checker: Checker): + alloc_s = alloc_cursor._node + assert isinstance(alloc_s, LoopIR.Alloc) + assert isinstance(alloc_dim, LoopIR.expr) + assert isinstance(indexing, LoopIR.expr) - Check_Bounds(ir, new_alloc, after_alloc) - return ir, fwd - - def dynamic_check(): + if checker == "static": + Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], alloc_dim) old_typ = alloc_s.type new_rngs = [alloc_dim] @@ -2411,10 +1954,13 @@ def mk_write(c): after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] - fuzz(get_rest_of_block(alloc_cursor), fwd) + if checker == "static": + Check_Bounds(ir, new_alloc, after_alloc) + else: + fuzz(get_rest_of_block(alloc_cursor), get_rest_of_block(fwd(alloc_cursor))) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoResizeDim( @@ -2424,68 +1970,14 @@ def DoResizeDim( offset: LoopIR.expr, check_mode: CheckMode, ): - alloc_s = alloc_cursor._node - alloc_name = alloc_s.name - assert isinstance(alloc_s, LoopIR.Alloc) - assert isinstance(alloc_s.type, T.Tensor) - - def static_check(): - Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], size) - - ir, fwd = ( - alloc_cursor._child_node("type") - ._child_block("hi")[dim_idx] - ._replace([size]) - ) - - def mk_read(c): - rd = c._node - - def mk_binop(e): - return LoopIR.BinOp("-", e, offset, offset.type, rd.srcinfo) - - new_idx = rd.idx.copy() - if isinstance(rd, LoopIR.Read): - new_idx[dim_idx] = mk_binop(rd.idx[dim_idx]) - return {"idx": new_idx} - - elif isinstance(rd, LoopIR.WindowExpr): - if isinstance(rd.idx[dim_idx], LoopIR.Point): - new_idx[dim_idx] = LoopIR.Point( - mk_binop(rd.idx[dim_idx].pt), rd.srcinfo - ) - else: - new_idx[dim_idx] = LoopIR.Interval( - mk_binop(rd.idx[dim_idx].lo), - mk_binop(rd.idx[dim_idx].hi), - rd.srcinfo, - ) - - return {"idx": new_idx} - else: - raise NotImplementedError( - f"Did not implement {type(rd)}. This may be a bug." 
- ) - - def mk_write(c): - s = c._node - new_idx = s.idx.copy() - new_idx[dim_idx] = LoopIR.BinOp( - "-", s.idx[dim_idx], offset, offset.type, s.srcinfo - ) - return {"idx": new_idx} - - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) - - new_alloc_cursor = fwd(alloc_cursor) - after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] + def check(checker: Checker): + alloc_s = alloc_cursor._node + alloc_name = alloc_s.name + assert isinstance(alloc_s, LoopIR.Alloc) + assert isinstance(alloc_s.type, T.Tensor) - Check_Bounds(ir, new_alloc_cursor._node, after_alloc) - return ir, fwd - - def dynamic_check(): + if checker == "static": + Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], size) ir, fwd = ( alloc_cursor._child_node("type") @@ -2537,10 +2029,13 @@ def mk_write(c): new_alloc_cursor = fwd(alloc_cursor) after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] - fuzz(get_rest_of_block(alloc_cursor), fwd) + if checker == "static": + Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + else: + fuzz(get_rest_of_block(alloc_cursor), get_rest_of_block(new_alloc_cursor)) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoRearrangeDim(decl_cursor, permute_vector): @@ -2623,65 +2118,20 @@ def mk_stride_expr(c): def DoDivideDim(alloc_cursor, dim_idx, quotient, check_mode: CheckMode): - alloc_s = alloc_cursor._node - alloc_sym = alloc_s.name - - assert isinstance(alloc_s, LoopIR.Alloc) - assert isinstance(dim_idx, int) - assert isinstance(quotient, int) - - old_typ = alloc_s.type - old_shp = old_typ.shape() - dim = old_shp[dim_idx] - - def static_check(): - Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) - numer = divide_expr(dim, quotient) - new_shp = ( - old_shp[:dim_idx] - + [ - numer, - LoopIR.Const(quotient, T.int, dim.srcinfo), - ] - + old_shp[dim_idx + 1 :] - ) - new_typ = T.Tensor(new_shp, False, old_typ.basetype()) - - ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) - - def remap_idx(idx): - orig_i = idx[dim_idx] - srcinfo = orig_i.srcinfo - quot = LoopIR.Const(quotient, T.int, srcinfo) - hi = LoopIR.BinOp("/", orig_i, quot, orig_i.type, srcinfo) - lo = LoopIR.BinOp("%", orig_i, quot, orig_i.type, srcinfo) - return idx[:dim_idx] + [hi, lo] + idx[dim_idx + 1 :] - - def mk_read(c): - rd = c._node - - if isinstance(rd, LoopIR.Read) and not rd.idx: - raise SchedulingError( - f"Cannot divide {alloc_sym} because buffer is passed as an argument" - ) - elif isinstance(rd, LoopIR.WindowExpr): - raise SchedulingError( - f"Cannot divide {alloc_sym} because the buffer is windowed later on" - ) - - return {"idx": remap_idx(rd.idx)} + def check(checker: Checker): + alloc_s = alloc_cursor._node + alloc_sym = alloc_s.name - def mk_write(c): - s = c._node - return {"idx": remap_idx(s.idx)} + assert isinstance(alloc_s, LoopIR.Alloc) + assert isinstance(dim_idx, int) + assert isinstance(quotient, int) - # TODO: add better iteration primitive - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) - return ir, fwd + old_typ = alloc_s.type + old_shp = old_typ.shape() + dim = old_shp[dim_idx] - def dynamic_check(): + if checker == "static": + Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) numer = divide_expr(dim, quotient) new_shp = 
( old_shp[:dim_idx] @@ -2725,10 +2175,11 @@ def mk_write(c): for c in get_rest_of_block(alloc_cursor): ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) - fuzz(get_rest_of_block(alloc_cursor), fwd) + if checker == "dynamic": + fuzz(get_rest_of_block(alloc_cursor), get_rest_of_block(fwd(alloc_cursor))) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoMultiplyDim(alloc_cursor, hi_idx, lo_idx): @@ -3210,7 +2661,7 @@ def _stmt(s): def DoRemoveLoop(loop, unsafe_disable_check, check_mode): - def static_check(): + def check(checker: Checker): s = loop._node # Check if we can remove the loop. Conditions are: @@ -3221,66 +2672,32 @@ def static_check(): ) # 2. Body is idempotent - if not unsafe_disable_check: + if not unsafe_disable_check and checker == "static": Check_IsIdempotent(loop.get_root(), [s]) # 3. The loop runs at least once; # If not, then place a guard around the statement - ir, fwd = loop.get_root(), lambda x: x + ir, fwd_move = loop.body()._move(loop.after()) + ir, fwd_del = fwd(loop)._delete() + fwd = _compose(fwd_del, fwd_move) try: - Check_IsPositiveExpr(loop.get_root(), [s], s.hi) + if checker == "static": + Check_IsPositiveExpr(loop.get_root(), [s], s.hi) + else: + fuzz(loop.as_block(), fwd(loop.body())) except SchedulingError: cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo) def wrapper(body): return LoopIR.If(cond, body, [], s.srcinfo) - ir, fwd = loop.body()._wrap(wrapper, "body") - - ir, fwd_move = fwd(loop).body()._move(fwd(loop).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(loop)._delete() - fwd = _compose(fwd_del, fwd) + ir, fwd = fwd(loop.body())._wrap(wrapper, "body") + if checker == "dynamic": + fuzz(loop.as_block(), fwd(loop.body()).parent().as_block()) return ir, fwd - def dynamic_check(): - s = loop._node - - # Check if we can remove the loop. Conditions are: - # 1. Body does not depend on the loop iteration variable - if s.iter in _FV(s.body): - raise SchedulingError( - f"Cannot remove loop, {s.iter} is not " "free in the loop body." - ) - - # 2. Body is idempotent - - # 3. The loop runs at least once; - # If not, then place a guard around the statement - ir1, fwd1 = loop.get_root(), lambda x: x - ir1, fwd_move1 = fwd1(loop).body()._move(fwd1(loop).after()) - fwd1 = _compose(fwd_move1, fwd1) - ir1, fwd_del1 = fwd1(loop)._delete() - fwd1 = _compose(fwd_del1, fwd1) - try: - fuzz(loop.parent(), fwd1) - return ir1, fwd1 - except SchedulingError: - - def wrapper(body): - return LoopIR.If(cond, body, [], s.srcinfo) - - ir2, fwd2 = loop.body()._wrap(wrapper, "body") - ir2, fwd_move2 = fwd2(loop).body()._move(fwd2(loop).after()) - fwd2 = _compose(fwd_move2, fwd2) - ir2, fwd_del2 = fwd2(loop)._delete() - fwd2 = _compose(fwd_del2, fwd2) - cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo) - fuzz(loop.parent(), fwd2) - return ir2, fwd2 - - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) # This is same as original FissionAfter, except that diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo.py index c19fdc088..96fd73ae5 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo.py @@ -392,6 +392,8 @@ def generate_test_case( class TestResult: buffer_values: dict[Sym, np.ndarray] ctxt_object: dict[str, Union[int, float]] + # map from JavaScript name of variable tracking coverage of different parts of coverage skeleton + # e.g. 
bool variable tracking whether branch of if statement gets executed coverage_result: Optional[dict[str, Union[bool, memoryview]]] @@ -494,6 +496,18 @@ def forward_to_test(self, cursor: Block) -> Optional[Block]: ) return None + def forward_staging_args( + self, staging_args: Optional[StageMemArgs] + ) -> Optional[StageMemArgs]: + if staging_args is None: + return None + forwarded_scope = self.forward_to_test(staging_args.scope) + if forwarded_scope is None: + return None + return StageMemArgs( + staging_args.buffer_sym, staging_args.staged_window_expr, forwarded_scope + ) + def backward_from_test(self, path: NodePath) -> NodePath: assert path.path[0][1] is not None return NodePath( @@ -525,11 +539,6 @@ def broaden(self) -> Optional["TestScope"]: else: return TestScope(self.scope._anchor.as_block()) - def transform(self, forward: Callable[[Cursor], Cursor]) -> "TestScope": - if self.broaden() is None: - return TestScope(forward(self.scope._anchor)._child_block("body")) - return TestScope(forward(self.scope)) - def get_type_map(self) -> dict[Sym, LoopIR.type]: root_proc = self.scope.get_root() proc_type_visitor = TypeVisitor() @@ -590,141 +599,72 @@ def get_test_spec( TEST_CASE_BOUND = 15 -MAX_FAILS = 3 +MAX_SKIPPED_TESTS = 3 MAX_ITERS = 20 -TIME_RECORDING_FILE = None # "./times.csv" - - -@dataclass -class Timer: - fuzz_start: Optional[int] = None - transpile_start: Optional[int] = None - constraint_start: Optional[int] = None - test_start: Optional[int] = None - scope_widen_count: int = 0 - fuzz_total: int = 0 - transpile_total: int = 0 - constraint_total: int = 0 - test_total: int = 0 - - def start_fuzz(self): - if TIME_RECORDING_FILE is None: - return - self.fuzz_start = time.process_time_ns() - - def widen_scope(self): - self.scope_widen_count += 1 - - def end_fuzz(self): - if TIME_RECORDING_FILE is None: - return - assert self.fuzz_start is not None - self.fuzz_total += time.process_time_ns() - self.fuzz_start - self.fuzz_start = None - - def start_transpile(self): - if TIME_RECORDING_FILE is None: - return - self.transpile_start = time.process_time_ns() - - def end_transpile(self): - if TIME_RECORDING_FILE is None: - return - assert self.transpile_start is not None - self.transpile_total += time.process_time_ns() - self.transpile_start - self.transpile_start = None - - def start_constraint(self): - if TIME_RECORDING_FILE is None: - return - self.constraint_start = time.process_time_ns() - - def end_constraint(self): - if TIME_RECORDING_FILE is None: - return - assert self.constraint_start is not None - self.constraint_total += time.process_time_ns() - self.constraint_start - self.constraint_start = None - def start_test(self): - if TIME_RECORDING_FILE is None: - return - self.test_start = time.process_time_ns() - - def end_test(self): - if TIME_RECORDING_FILE is None: - return - assert self.test_start is not None - self.test_total += time.process_time_ns() - self.test_start - self.test_start = None - - def record(self, failed: bool, unsolved: bool): - if TIME_RECORDING_FILE is None: - return - with open(TIME_RECORDING_FILE, "a") as recording: - recording.write( - f"{self.fuzz_total},{self.transpile_total},{self.constraint_total},{self.test_total},{self.scope_widen_count},{1 if failed else 0},{1 if unsolved else 0}\n" - ) - - -def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): - timer = Timer() - timer.start_fuzz() - starting_scope = Cursor.create(starting_scope.get_root()) - if isinstance(starting_scope, Node) and starting_scope.depth() == 0: - 
starting_scope = starting_scope.body() - - starting_scope = ( - starting_scope.as_block() - if isinstance(starting_scope, Node) - else starting_scope - ) - failure_scope = starting_scope - failure_transformed_scope = fwd(failure_scope) - assert isinstance(failure_transformed_scope, Block) - cur_scope = TestScope(starting_scope) - cur_type_map = cur_scope.get_type_map() - transformed_type_map = cur_scope.transform(fwd).get_type_map() - - while cur_scope is not None: - timer.start_transpile() - transformed = cur_scope.transform(fwd) - cm = ConstraintMaker(cur_type_map | transformed_type_map) - spec1 = cur_scope.get_test_spec(cm, cur_type_map) - spec2 = transformed.get_test_spec(cm, transformed_type_map) +def fuzz( + scope1: Block, + scope2: Block, + staging_args: Optional[StageMemArgs] = None, +): + """ + scope1: smallest scope containing all changes made by scheduling op in original program + scope2: scope corresponding to starting scope in transformed program + staging_args: arguments to stage_mem scheduling op + """ + cur_scope1 = TestScope(scope1) + cur_scope2 = TestScope(scope2) + cur_type_map1 = cur_scope1.get_type_map() + cur_type_map2 = cur_scope2.get_type_map() + + while cur_scope1 is not None: + assert cur_scope2 is not None + cm = ConstraintMaker(cur_type_map1 | cur_type_map2) + + spec1 = cur_scope1.get_test_spec(cm, cur_type_map1) + spec2 = cur_scope2.get_test_spec(cm, cur_type_map2) transpiled_test1 = Transpiler( + # new proc that contains the current scope as a body, not the entire proc spec1.proc, - CoverageArgs(cm, spec1.var_renaming, spec1.forward_to_test(failure_scope)), + CoverageArgs( + cm, + spec1.var_renaming, + spec1.forward_to_test(scope1), + spec1.forward_staging_args(staging_args), + ), ) transpiled_test2 = Transpiler( spec2.proc, CoverageArgs( - cm, spec2.var_renaming, spec2.forward_to_test(failure_transformed_scope) + cm, + spec2.var_renaming, + spec2.forward_to_test(scope2), + spec2.forward_staging_args(staging_args), ), ) config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs() arg_types = spec1.arg_types | spec2.arg_types + # precondition of current scope in both original and transformed program constraint = spec1.constraint.union(spec2.constraint) skeleton1, skeleton2 = ( transpiled_test1.get_coverage_skeleton(), transpiled_test2.get_coverage_skeleton(), ) assert skeleton1 is not None and skeleton2 is not None + # symbolic representation of control flow in both original and transformed scope coverage_skeleton = skeleton1.merge(skeleton2) tests_passed = True - fails = 0 + skipped_tests = 0 iters = 0 - timer.end_transpile() while ( not coverage_skeleton.get_coverage_progress().is_finished() and iters < MAX_ITERS and tests_passed ): - timer.start_constraint() test_case = generate_test_case( arg_types, config_fields, @@ -732,22 +672,20 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): coverage_skeleton, cm, ) - timer.end_constraint() + # if constraint is unsolvable if test_case is None: - fails += 1 - if fails > MAX_FAILS: - timer.end_fuzz() - timer.record(False, True) - return + skipped_tests += 1 + if skipped_tests > MAX_SKIPPED_TESTS: + # program should pass but not testing it is probably bad + assert False else: continue - timer.start_test() out1 = run_test_case(test_case, transpiled_test1) out2 = run_test_case(test_case, transpiled_test2) if out1 == "failed" or out2 == "failed": + # precondition in called subproc failed or out of bounds access tests_passed = False - timer.end_test() break assert 
out1.coverage_result is not None and out2.coverage_result is not None coverage_skeleton.update_coverage( @@ -759,31 +697,17 @@ def fuzz(starting_scope: Union[Block, Node], fwd: Callable[[Cursor], Cursor]): ): tests_passed = False break - if cur_scope.broaden() is not None: - for ctxt_name in out1.ctxt_object & out2.ctxt_object.keys(): + if cur_scope1.broaden() is not None: + for ctxt_name in out1.ctxt_object.keys() & out2.ctxt_object.keys(): if not np.allclose( out1.ctxt_object[ctxt_name], out2.ctxt_object[ctxt_name] ): tests_passed = False break - timer.end_test() iters += 1 if tests_passed: - timer.end_fuzz() - timer.record(False, False) return else: - timer.widen_scope() - cur_scope = cur_scope.broaden() - timer.end_fuzz() - timer.record(True, False) + cur_scope1 = cur_scope1.broaden() + cur_scope2 = cur_scope2.broaden() raise SchedulingError("tests failed at broadest scope") - - -def fuzz_reorder_stmts(s1: Node, s2: Node): - starting_scope = s1.as_block().expand(0, 1) - _, fwd = s2._move(s1.before()) - patched_fwd = lambda cursor: ( - fwd(cursor) if isinstance(cursor, Node) else fwd(s2).as_block().expand(0, 1) - ) - fuzz(starting_scope, patched_fwd) From c5957a335e6dc5f39bb30fc3446273a58a9ef3d7 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Mon, 26 May 2025 17:13:43 -0400 Subject: [PATCH 20/24] idk --- src/exo/backend/LoopIR_compiler.py | 6 ++++++ src/exo/rewrite/LoopIR_scheduling.py | 8 +++----- tests/test_apps.py | 3 +-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/exo/backend/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py index 72586a8f2..fa52ab6be 100644 --- a/src/exo/backend/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -15,6 +15,8 @@ from ..core.prelude import * from .win_analysis import WindowAnalysis from ..rewrite.range_analysis import IndexRangeEnvironment +from ..rewrite.chexo import fuzz +from ..core.internal_cursors import Cursor DEFAULT_CHECK_MODE = "both" @@ -415,7 +417,11 @@ def from_lines(x): is_public_decl = id(p) in orig_procs if check_mode != "dynamic": + # fixme: need to check parallel analysis on static procs even if one of the procs is dynamic p = ParallelAnalysis().run(p) + else: + proc_cursor = Cursor.create(p).body() + fuzz(proc_cursor, proc_cursor) p = PrecisionAnalysis().run(p) p = WindowAnalysis().apply_proc(p) p = MemoryAnalysis().run(p) diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index 22a4e1561..11d8adc5e 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -434,13 +434,13 @@ def divide_expr(e, quot): # --------------------------------------------------------------------------- # # Scheduling directives -T = TypeVar("T") +DoCheckReturn = TypeVar("DoCheckReturn") def do_check( - check: Callable[[Checker], T], + check: Callable[[Checker], DoCheckReturn], mode: CheckMode, -) -> T: +) -> DoCheckReturn: if mode == "both": e_static, e_dynamic = None, None trb_static, trb_dynamic = None, None @@ -494,8 +494,6 @@ def check(checker: Checker): def DoParallelizeLoop(loop_cursor, check_mode: CheckMode): def check(checker: Checker): ir, fwd = loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) - if checker == "dynamic": - fuzz(loop_cursor.as_block(), fwd(loop_cursor).as_block()) return ir, fwd return do_check(check, check_mode) diff --git a/tests/test_apps.py b/tests/test_apps.py index ed5fac092..960fec521 100644 --- a/tests/test_apps.py +++ b/tests/test_apps.py @@ -47,8 +47,7 @@ def test_gemmini_matmul(golden): 
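[Editor's note, illustrative aside — not part of the patch.] For readers following the refactor above: every scheduling directive now builds a single check(checker) closure and hands it to do_check, which runs it under the static analyzer, the dynamic fuzzer, or both. The snippet below is a minimal, self-contained sketch of that dispatch pattern only; toy_rewrite and ToyError are hypothetical stand-ins, and the real do_check additionally reconciles errors raised under each checker and compares the object code the two paths produce.

from typing import Callable, Literal

Checker = Literal["static", "dynamic"]
CheckMode = Literal["static", "dynamic", "both"]


class ToyError(Exception):
    """Stands in for SchedulingError in this sketch."""


def do_check(check: Callable[[Checker], str], mode: CheckMode) -> str:
    # "both" runs the same rewrite under each checker and insists the
    # results agree, mirroring the str() comparison in the patch above.
    if mode == "both":
        static_res = check("static")
        dynamic_res = check("dynamic")
        assert static_res == dynamic_res, "static and dynamic rewrites diverged"
        return static_res
    return check(mode)


def toy_rewrite(program: list[str], mode: CheckMode) -> str:
    # A directive supplies one closure; the checker flag only selects the
    # safety check, while the rewrite itself is shared by both paths.
    def check(checker: Checker) -> str:
        if checker == "static":
            if "unsafe" in program:
                raise ToyError("static analysis rejected the rewrite")
        else:
            # The real patch calls fuzz(scope1, scope2) here; the sketch
            # simply assumes every concrete test case passes.
            pass
        return "; ".join(reversed(program))

    return do_check(check, mode)


if __name__ == "__main__":
    print(toy_rewrite(["a = 1", "b = 2"], "both"))  # prints "b = 2; a = 1"
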
@pytest.mark.slow def test_gemmini_conv(golden): module_file = REPO_ROOT / "apps" / "gemmini" / "src" / "exo" / "conv.py" - # TODO: uncomment when conv is fixed in main - # assert _test_app(module_file) == golden + assert _test_app(module_file) == golden def test_blur(golden): From 9385bbc778b496f399ef4a244e72b7444a8e3ce7 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Mon, 26 May 2025 17:19:42 -0400 Subject: [PATCH 21/24] move things --- src/exo/backend/LoopIR_compiler.py | 2 +- src/exo/libs/externs.py | 2 +- src/exo/rewrite/LoopIR_scheduling.py | 2 +- .../chexo}/LoopIR_transpiler.py | 10 +++++----- src/exo/rewrite/{ => chexo}/chexo.py | 16 ++++++++-------- src/exo/rewrite/{ => chexo}/constraint_solver.py | 8 ++++---- src/exo/{backend => rewrite/chexo}/coverage.py | 6 +++--- tests/test_chexo.py | 4 ++-- tests/test_constraint_solver.py | 4 ++-- tests/test_transpiler.py | 6 +++--- 10 files changed, 30 insertions(+), 30 deletions(-) rename src/exo/{backend => rewrite/chexo}/LoopIR_transpiler.py (99%) rename src/exo/rewrite/{ => chexo}/chexo.py (98%) rename src/exo/rewrite/{ => chexo}/constraint_solver.py (99%) rename src/exo/{backend => rewrite/chexo}/coverage.py (99%) diff --git a/src/exo/backend/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py index fa52ab6be..fe016d446 100644 --- a/src/exo/backend/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -15,7 +15,7 @@ from ..core.prelude import * from .win_analysis import WindowAnalysis from ..rewrite.range_analysis import IndexRangeEnvironment -from ..rewrite.chexo import fuzz +from ..rewrite.chexo.chexo import fuzz from ..core.internal_cursors import Cursor DEFAULT_CHECK_MODE = "both" diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py index cc9bbf78d..472224db1 100644 --- a/src/exo/libs/externs.py +++ b/src/exo/libs/externs.py @@ -1,7 +1,7 @@ from ..core.extern import Extern, _EErr import numpy as np -from ..rewrite.constraint_solver import Constraint, DisjointConstraint, Expression +from ..rewrite.chexo.constraint_solver import Constraint, DisjointConstraint, Expression from ..core.prelude import Sym diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index 11d8adc5e..b75d6bbe0 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -43,7 +43,7 @@ Check_ExprBound, Check_Aliasing, ) -from .chexo import fuzz +from .chexo.chexo import fuzz from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis from ..core.internal_cursors import Block, Node diff --git a/src/exo/backend/LoopIR_transpiler.py b/src/exo/rewrite/chexo/LoopIR_transpiler.py similarity index 99% rename from src/exo/backend/LoopIR_transpiler.py rename to src/exo/rewrite/chexo/LoopIR_transpiler.py index 3d9c427a4..999b59a5c 100644 --- a/src/exo/backend/LoopIR_transpiler.py +++ b/src/exo/rewrite/chexo/LoopIR_transpiler.py @@ -3,10 +3,10 @@ from string import Template from typing import Any, Callable, Generator, Iterable, Optional, Union -from ..core.configs import Config +from ...core.configs import Config -from ..core.prelude import Sym -from ..core.LoopIR import LoopIR, T +from ...core.prelude import Sym +from ...core.LoopIR import LoopIR, T from .coverage import ( CoverageSkeleton, CoverageSkeletonNode, @@ -23,8 +23,8 @@ StagingOverlap, SymbolicWindowIndex, ) -from ..core.internal_cursors import Block, Cursor, Node, NodePath -from ..rewrite.constraint_solver import ( +from ...core.internal_cursors import Block, Cursor, Node, NodePath +from 
.constraint_solver import ( TRUE_CONSTRAINT, Constraint, ConstraintMaker, diff --git a/src/exo/rewrite/chexo.py b/src/exo/rewrite/chexo/chexo.py similarity index 98% rename from src/exo/rewrite/chexo.py rename to src/exo/rewrite/chexo/chexo.py index 96fd73ae5..54dc89c8e 100644 --- a/src/exo/rewrite/chexo.py +++ b/src/exo/rewrite/chexo/chexo.py @@ -2,19 +2,19 @@ import time from typing import Callable, Literal, Optional, Union -from ..core.internal_cursors import Cursor, Block, Node, NodePath +from ...core.internal_cursors import Cursor, Block, Node, NodePath -from ..backend.LoopIR_transpiler import CoverageArgs, StageMemArgs, Transpiler -from ..backend.coverage import CoverageSkeleton +from .LoopIR_transpiler import CoverageArgs, StageMemArgs, Transpiler +from .coverage import CoverageSkeleton -from ..core.configs import Config +from ...core.configs import Config -from ..core.LoopIR import LoopIR, T +from ...core.LoopIR import LoopIR, T from dataclasses import dataclass, field -from ..core.prelude import Sym, SrcInfo -from ..core.memory import DRAM, Memory +from ...core.prelude import Sym, SrcInfo +from ...core.memory import DRAM, Memory import numpy as np -from .new_eff import SchedulingError +from ..new_eff import SchedulingError from .constraint_solver import ( TRUE_CONSTRAINT, Constraint, diff --git a/src/exo/rewrite/constraint_solver.py b/src/exo/rewrite/chexo/constraint_solver.py similarity index 99% rename from src/exo/rewrite/constraint_solver.py rename to src/exo/rewrite/chexo/constraint_solver.py index bc106207f..be53ac8ce 100644 --- a/src/exo/rewrite/constraint_solver.py +++ b/src/exo/rewrite/chexo/constraint_solver.py @@ -1,10 +1,10 @@ from dataclasses import dataclass, field from typing import Callable, Literal, Union, Optional -from ..core.configs import Config -from ..core.prelude import Sym -from ..core.LoopIR import LoopIR, T -from ..core.extern import Extern +from ...core.configs import Config +from ...core.prelude import Sym +from ...core.LoopIR import LoopIR, T +from ...core.extern import Extern import numpy as np from scipy.optimize import linprog from hsnf import smith_normal_form diff --git a/src/exo/backend/coverage.py b/src/exo/rewrite/chexo/coverage.py similarity index 99% rename from src/exo/backend/coverage.py rename to src/exo/rewrite/chexo/coverage.py index 350154655..5d03059f2 100644 --- a/src/exo/backend/coverage.py +++ b/src/exo/rewrite/chexo/coverage.py @@ -3,7 +3,7 @@ from typing import Generator, Iterable, Optional, Union import numpy as np -from ..rewrite.constraint_solver import ( +from .constraint_solver import ( Constraint, ConstraintMaker, ConstraintTerm, @@ -12,8 +12,8 @@ Expression, Solution, ) -from ..core.prelude import Sym -from ..core.internal_cursors import Node, NodePath +from ...core.prelude import Sym +from ...core.internal_cursors import Node, NodePath @dataclass diff --git a/tests/test_chexo.py b/tests/test_chexo.py index 22c0ac3b2..7c6da8226 100644 --- a/tests/test_chexo.py +++ b/tests/test_chexo.py @@ -1,12 +1,12 @@ from __future__ import annotations from exo.core.prelude import Sym -from exo.rewrite.chexo import ( +from exo.rewrite.chexo.chexo import ( TypeVisitor, get_free_variables, collect_path_constraints, ) -from exo.rewrite.constraint_solver import ConstraintMaker +from exo.rewrite.chexo.constraint_solver import ConstraintMaker from exo import proc, config from exo.core.memory import StaticMemory diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py index b597ede32..793690b85 100644 --- 
a/tests/test_constraint_solver.py +++ b/tests/test_constraint_solver.py @@ -1,10 +1,10 @@ from __future__ import annotations from exo.core.prelude import Sym -from exo.rewrite.constraint_solver import ConstraintMaker, DisjointConstraint +from exo.rewrite.chexo.constraint_solver import ConstraintMaker, DisjointConstraint from exo.core.LoopIR import T from exo import proc -from exo.rewrite.chexo import TypeVisitor +from exo.rewrite.chexo.chexo import TypeVisitor def stringify_proc_constraint(p, invert=False): diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index afc19cb3b..3b201c285 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -1,11 +1,11 @@ from __future__ import annotations from exo.core.prelude import Sym -from exo.rewrite.constraint_solver import ConstraintMaker, DisjointConstraint +from exo.rewrite.chexo.constraint_solver import ConstraintMaker, DisjointConstraint from exo.core.LoopIR import T from exo import proc -from exo.rewrite.chexo import TypeVisitor -from exo.backend.LoopIR_transpiler import Transpiler, CoverageArgs +from exo.rewrite.chexo.chexo import TypeVisitor +from exo.rewrite.chexo.LoopIR_transpiler import Transpiler, CoverageArgs def get_coverage_args(p) -> CoverageArgs: From ad08a9c1cbceabb11e20f39fc9a9e0903d93bc9c Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 27 May 2025 16:36:13 -0400 Subject: [PATCH 22/24] refactor + stagemem --- src/exo/API_scheduling.py | 17 +- src/exo/backend/LoopIR_compiler.py | 6 +- src/exo/rewrite/LoopIR_scheduling.py | 1077 +++++++++----------- src/exo/rewrite/chexo/LoopIR_transpiler.py | 45 +- src/exo/rewrite/chexo/chexo.py | 133 ++- src/exo/rewrite/chexo/constraint_solver.py | 2 +- src/exo/rewrite/chexo/coverage.py | 95 +- 7 files changed, 681 insertions(+), 694 deletions(-) diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index e71d7e031..ac66295b9 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -1259,7 +1259,7 @@ def delete_stmt(proc, stmt_cursor): rewrite: `s1 ; s2 ; s3 -> s1 ; s3` """ - ir, fwd = scheduling.DoDeleteStmt(proc._root(), stmt_cursor._impl, proc._check_mode) + ir, fwd = scheduling.DoDeleteStmt(stmt_cursor._impl, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1268,12 +1268,10 @@ def insert_mutate(proc, gap_cursor, buf_read, rhs, is_reduce): if not (isinstance(buf_read, LoopIR.Read) and len(buf_read.idx) == 0): raise SchedulingError() new_stmt = (LoopIR.Reduce if is_reduce else LoopIR.Assign)( - buf_read.name, buf_read.type, rhs, buf_read.srcinfo + buf_read.name, buf_read.type, [], rhs, buf_read.srcinfo ) - ir, fwd = scheduling.DoInsertStmt( - proc._root(), gap_cursor._impl, new_stmt, proc._check_mode - ) - return Procedure(ir, __provenance_eq_Procedure=proc, _forward=fwd) + ir, fwd = scheduling.DoInsertStmt(gap_cursor._impl, new_stmt, proc._check_mode) + return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @sched_op([GapCursorA, ConfigA, ConfigFieldA, NewExprA("gap_cursor")]) @@ -1697,7 +1695,12 @@ def stage_mem(proc, block_cursor, win_expr, new_buf_name, accum=False): """ buf_name, w_exprs = win_expr ir, fwd = scheduling.DoStageMem( - block_cursor._impl, buf_name, w_exprs, new_buf_name, use_accum_zero=accum + block_cursor._impl, + buf_name, + w_exprs, + new_buf_name, + proc._check_mode, + use_accum_zero=accum, ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) diff --git a/src/exo/backend/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py index 
fe016d446..7bd41eb5a 100644 --- a/src/exo/backend/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -15,10 +15,10 @@ from ..core.prelude import * from .win_analysis import WindowAnalysis from ..rewrite.range_analysis import IndexRangeEnvironment -from ..rewrite.chexo.chexo import fuzz +from ..rewrite.chexo.chexo import fuzz, fuzz_single_scope from ..core.internal_cursors import Cursor -DEFAULT_CHECK_MODE = "both" +DEFAULT_CHECK_MODE = "static" def sanitize_str(s): @@ -421,7 +421,7 @@ def from_lines(x): p = ParallelAnalysis().run(p) else: proc_cursor = Cursor.create(p).body() - fuzz(proc_cursor, proc_cursor) + fuzz_single_scope(proc_cursor) p = PrecisionAnalysis().run(p) p = WindowAnalysis().apply_proc(p) p = MemoryAnalysis().run(p) diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index b75d6bbe0..2ad97208e 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -43,10 +43,11 @@ Check_ExprBound, Check_Aliasing, ) -from .chexo.chexo import fuzz +from .chexo.chexo import fuzz, fuzz_single_scope +from .chexo.LoopIR_transpiler import StageMemArgs, StagedWindowExpr from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis -from ..core.internal_cursors import Block, Node +from ..core.internal_cursors import Block, Cursor, Node from ..core.prelude import * from ..core.proc_eqv import get_strictest_eqv_proc @@ -434,24 +435,34 @@ def divide_expr(e, quot): # --------------------------------------------------------------------------- # # Scheduling directives -DoCheckReturn = TypeVar("DoCheckReturn") +CheckResult = Union[ + tuple[LoopIR.proc, Callable[[Cursor], Cursor]], + tuple[ + LoopIR.proc, + Callable[ + [Cursor], + Cursor, + ], + Any, + ], +] def do_check( - check: Callable[[Checker], DoCheckReturn], + check: Callable[[Checker], CheckResult], mode: CheckMode, -) -> DoCheckReturn: +) -> CheckResult: if mode == "both": e_static, e_dynamic = None, None trb_static, trb_dynamic = None, None - static_res = None + static_res, dynamic_res = None, None try: static_res = check("static") except Exception as e: e_static = e trb_static = traceback.format_exc() try: - check("dynamic") + dynamic_res = check("dynamic") except Exception as e: e_dynamic = e trb_dynamic = traceback.format_exc() @@ -462,6 +473,11 @@ def do_check( elif e_static is not None: raise e_static else: + assert static_res is not None and dynamic_res is not None + if str(static_res[0]) != str(dynamic_res[0]): + assert ( + False + ), f"resulting object code differs between static and dynamic.\nstatic:\n{str(static_res[0])}\ndynamic:\n{str(dynamic_res[0])}" return static_res else: return check(mode) @@ -1452,7 +1468,8 @@ def check(checker: Checker): def DoBindConfig(config, field, expr_cursor, check_mode): def check(checker: Checker): e = expr_cursor._node - assert isinstance(e, LoopIR.Read) + if checker == "static": + assert isinstance(e, LoopIR.Read) c = expr_cursor while not isinstance(c._node, LoopIR.stmt): @@ -1564,10 +1581,8 @@ def check(checker: Checker): init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) new_name_sym = Sym(new_name) - alloc_s = LoopIR.Alloc(new_name_sym, expr.type.basetype(), DRAM, expr.srcinfo) - assign_s = LoopIR.Assign( - new_name_sym, expr.type.basetype(), [], expr, expr.srcinfo - ) + alloc_s = LoopIR.Alloc(new_name_sym, et.basetype(), DRAM, expr.srcinfo) + assign_s = LoopIR.Assign(new_name_sym, et.basetype(), [], expr, expr.srcinfo) ir, fwd = init_s.before()._insert([alloc_s, assign_s]) new_read 
= LoopIR.Read(new_name_sym, [], expr.type, expr.srcinfo) @@ -2702,7 +2717,7 @@ def wrapper(body): # this does not remove loop. We have separate remove_loop # operator for that purpose. def DoFissionAfterSimple(stmt_cursor, n_lifts_start, unsafe_disable_checks, check_mode): - def static_check(): + def check(checker: Checker): n_lifts = n_lifts_start tgt_stmt = stmt_cursor._node assert isinstance(tgt_stmt, LoopIR.stmt) @@ -2757,7 +2772,7 @@ def alloc_check(pre, post): # we must check whether the two parts of the # fission can commute appropriately no_loop_var_pre = par_s.iter not in _FV(pre) - if not unsafe_disable_checks: + if not unsafe_disable_checks and checker == "static": Check_FissionLoop(ir, par_s, pre, post, no_loop_var_pre) # we can skip the loop iteration if the @@ -2806,114 +2821,15 @@ def wrapper(orelse): cur_c = fwd_move(fwd_wrap(par_c)) + if checker == "dynamic": + scope_cursor = stmt_cursor + for _ in range(n_lifts_start): + scope_cursor = scope_cursor.parent() + if n_lifts_start > 0: + fuzz(scope_cursor.as_block(), cur_c.as_block().expand(0, 1)) return ir, fwd - def dynamic_check(): - n_lifts = n_lifts_start - tgt_stmt = stmt_cursor._node - assert isinstance(tgt_stmt, LoopIR.stmt) - assert is_pos_int(n_lifts) - - ir, fwd = stmt_cursor.get_root(), lambda x: x - - def alloc_check(pre, post): - if not _is_alloc_free(pre, post): - pre_allocs = {s.name for s in pre if isinstance(s, LoopIR.Alloc)} - post_FV = _FV(post) - for nm in pre_allocs: - if nm in post_FV: - raise SchedulingError( - f"Will not fission here, because " - f"doing so will hide the allocation " - f"of {nm} from a later use site." - ) - - cur_c = stmt_cursor - while n_lifts > 0: - n_lifts -= 1 - - idx = cur_c.get_index() + 1 - par_c = cur_c.parent() - par_s = par_c._node - - if isinstance(par_s, LoopIR.For): - pre_c = par_c.body()[:idx] - post_c = par_c.body()[idx:] - elif isinstance(par_s, LoopIR.If): - if cur_c._node in par_s.body: - pre_c = par_c.body()[:idx] - post_c = par_c.body()[idx:] - else: - pre_c = par_c.orelse()[:idx] - post_c = par_c.orelse()[idx:] - else: - raise SchedulingError( - "Can only lift past a for loop or an if statement" - ) - - pre = [s._node for s in pre_c] - post = [s._node for s in post_c] - - if not (pre and post): - continue - - alloc_check(pre, post) - - if isinstance(par_s, LoopIR.For): - # we must check whether the two parts of the - # fission can commute appropriately - no_loop_var_pre = par_s.iter not in _FV(pre) - - # we can skip the loop iteration if the - # body doesn't depend on the loop - # and the body is idempotent - - def wrapper(body): - return par_s.update(body=body) - - ir, fwd_wrap = post_c._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - post_c = fwd_wrap(par_c).body()[-1] - ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) - fwd = _compose(fwd_move, fwd) - - cur_c = fwd_move(fwd_wrap(par_c)) - elif isinstance(par_s, LoopIR.If): - if cur_c._node in par_s.body: - - def wrapper(body): - return par_s.update(body=body, orelse=[]) - - ir, fwd_wrap = pre_c._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - pre_c = fwd_wrap(par_c).body()[0] - ir, fwd_move = pre_c._move(fwd_wrap(par_c).before()) - fwd = _compose(fwd_move, fwd) - - cur_c = fwd_move(fwd_wrap(par_c)).prev() - else: - assert cur_c._node in par_s.orelse - - def wrapper(orelse): - return par_s.update( - body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse - ) - - ir, fwd_wrap = post_c._wrap(wrapper, "orelse") - fwd = _compose(fwd_wrap, fwd) - - post_c = fwd_wrap(par_c).orelse()[-1] - ir, fwd_move 
= post_c._move(fwd_wrap(par_c).after()) - fwd = _compose(fwd_move, fwd) - - cur_c = fwd_move(fwd_wrap(par_c)) - - fuzz(stmt_cursor, fwd) - return ir, fwd - - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) # TODO: Deprecate this with the one above @@ -3100,32 +3016,32 @@ def are_allocs_used_after_block(): def DoFuseLoop(f_cursor, s_cursor, check_mode: CheckMode, unsafe_disable_check=False): - proc = f_cursor.get_root() + def check(checker: Checker): + proc = f_cursor.get_root() - if f_cursor.next() != s_cursor: - raise SchedulingError( - f"expected the two loops to be fused to come one right after the other. However, the statement after the first loop is:\n{f_cursor.next()._node}\n, not the provided second loop:\n {s_cursor._node}" - ) + if f_cursor.next() != s_cursor: + raise SchedulingError( + f"expected the two loops to be fused to come one right after the other. However, the statement after the first loop is:\n{f_cursor.next()._node}\n, not the provided second loop:\n {s_cursor._node}" + ) - # check if the loop bounds are equivalent - loop1 = f_cursor._node - loop2 = s_cursor._node - Check_ExprEqvInContext(proc, loop1.hi, [loop1], loop2.hi, [loop2]) + # check if the loop bounds are equivalent + loop1 = f_cursor._node + loop2 = s_cursor._node + Check_ExprEqvInContext(proc, loop1.hi, [loop1], loop2.hi, [loop2]) - def mk_read(e): - return LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) + def mk_read(e): + return LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) - ir, fwd = proc, lambda x: x - ir, fwd = _replace_reads( - ir, fwd, s_cursor, loop2.iter, mk_read, only_replace_attrs=False - ) - ir, fwd_move = fwd(s_cursor).body()._move(fwd(f_cursor).body()[-1].after()) - fwd = _compose(fwd_move, fwd) - ir, fwdDel = fwd(s_cursor)._delete() - fwd = _compose(fwdDel, fwd) + ir, fwd = proc, lambda x: x + ir, fwd = _replace_reads( + ir, fwd, s_cursor, loop2.iter, mk_read, only_replace_attrs=False + ) + ir, fwd_move = fwd(s_cursor).body()._move(fwd(f_cursor).body()[-1].after()) + fwd = _compose(fwd_move, fwd) + ir, fwdDel = fwd(s_cursor)._delete() + fwd = _compose(fwdDel, fwd) - def static_check(): - if not unsafe_disable_check: + if not unsafe_disable_check and checker == "static": x = LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) y = loop2.iter body1 = loop1.body @@ -3133,47 +3049,16 @@ def static_check(): loop = fwd(f_cursor)._node Check_FissionLoop(ir, loop, body1, body2) - def dynamic_check(): - fuzz(f_cursor.as_block().expand(delta_lo=0, delta_hi=1), fwd) + if checker == "dynamic": + fuzz(f_cursor.as_block().expand(delta_lo=0, delta_hi=1), fwd) - do_check(static_check, dynamic_check, check_mode) + return ir, fwd - return ir, fwd + return do_check(check, check_mode) def DoFuseIf(f_cursor, s_cursor, check_mode: CheckMode): - def static_check(): - proc = f_cursor.get_root() - if f_cursor.next() != s_cursor: - raise SchedulingError( - "expected the two if statements to be fused to come one right after the other" - ) - - if1 = f_cursor._node - if2 = s_cursor._node - Check_ExprEqvInContext(proc, if1.cond, [if1], if2.cond, [if2]) - - cond = if1.cond - body1 = if1.body - body2 = if2.body - orelse1 = if1.orelse - orelse2 = if2.orelse - ifstmt = LoopIR.If(cond, body1 + body2, orelse1 + orelse2, if1.srcinfo) - - ir, fwd = s_cursor.body()._move(f_cursor.body()[-1].after()) - if f_cursor.orelse(): - ir, fwd_move = ( - fwd(s_cursor).orelse()._move(fwd(f_cursor).orelse()[-1].after()) - ) - fwd = _compose(fwd_move, fwd) - else: - ir, fwd_repl = 
fwd(f_cursor).orelse()._replace(orelse1 + orelse2) - fwd = _compose(fwd_repl, fwd) - ir, fwd_del = fwd(s_cursor)._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd - - def dynamic_check(): + def check(checker: Checker): proc = f_cursor.get_root() if f_cursor.next() != s_cursor: raise SchedulingError( @@ -3182,6 +3067,8 @@ def dynamic_check(): if1 = f_cursor._node if2 = s_cursor._node + if checker == "static": + Check_ExprEqvInContext(proc, if1.cond, [if1], if2.cond, [if2]) cond = if1.cond body1 = if1.body @@ -3201,18 +3088,19 @@ def dynamic_check(): fwd = _compose(fwd_repl, fwd) ir, fwd_del = fwd(s_cursor)._delete() fwd = _compose(fwd_del, fwd) - fuzz(f_cursor.as_block().expand(delta_lo=0, delta_hi=1), fwd) + if checker == "dynamic": + fuzz(f_cursor.as_block().expand(0, 1), fwd(f_cursor).as_block()) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoAddLoop(stmt_cursor, var, hi, guard, unsafe_disable_check, check_mode: CheckMode): - def static_check(): + def check(checker: Checker): proc = stmt_cursor.get_root() s = stmt_cursor._node - if not unsafe_disable_check: + if not unsafe_disable_check and checker == "static": Check_IsIdempotent(proc, [s]) Check_IsPositiveExpr(proc, [s], hi) @@ -3235,35 +3123,11 @@ def wrapper(body): ) ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") + if checker == "dynamic": + fuzz(stmt_cursor.as_block(), fwd(stmt_cursor).parent().as_block()) return ir, fwd - def dynamic_check(): - proc = stmt_cursor.get_root() - s = stmt_cursor._node - - sym = Sym(var) - - def wrapper(body): - if guard: - rdsym = LoopIR.Read(sym, [], T.index, s.srcinfo) - zero = LoopIR.Const(0, T.int, s.srcinfo) - cond = LoopIR.BinOp("==", rdsym, zero, T.bool, s.srcinfo) - body = [LoopIR.If(cond, body, [], s.srcinfo)] - - return LoopIR.For( - sym, - LoopIR.Const(0, T.index, s.srcinfo), - hi, - body, - LoopIR.Seq(), - s.srcinfo, - ) - - ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") - fuzz(stmt_cursor.parent(), fwd) - return ir, fwd - - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) # --------------------------------------------------------------------------- # @@ -3326,52 +3190,58 @@ def err_handler(_, msg): def DoDeleteConfig(proc_cursor, config_cursor, check_mode: CheckMode): - def static_check(): - eq_mod_config = Check_DeleteConfigWrite( - proc_cursor._node, [config_cursor._node] - ) + def check(checker: Checker): + if checker == "static": + eq_mod_config = Check_DeleteConfigWrite( + proc_cursor._node, [config_cursor._node] + ) + else: + eq_mod_config = None p, fwd = config_cursor._delete() + if checker == "dynamic": + scope = config_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + else: + scope = scope.as_block() + fuzz(scope, fwd(scope)) return p, fwd, eq_mod_config - def dynamic_check(): - scope = config_cursor.parent() - if scope.depth() == 0: - scope = scope._child_block("body") - p, fwd = config_cursor._delete() - fuzz(scope, fwd) - return p, fwd, None - - return do_check(static_check, dynamic_check, check_mode) - - -def DoDeleteStmt(proc_cursor, stmt_cursor, check_mode: CheckMode): - def static_check(): - assert False, "check must be done with chexo" + return do_check(check, check_mode) - def dynamic_check(): - scope = stmt_cursor.parent() - if scope.depth() == 0: - scope = scope._child_block("body") - p, fwd = stmt_cursor._delete() - fuzz(scope, fwd) - return p, fwd - return do_check(static_check, dynamic_check, 
check_mode) +def DoDeleteStmt(stmt_cursor, check_mode: CheckMode): + def check(checker: Checker): + if checker == "static": + assert False, "check must be done with chexo" + else: + scope = stmt_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + else: + scope = scope.as_block() + p, fwd = stmt_cursor._delete() + fuzz(scope, fwd(scope)) + return p, fwd + return do_check(check, check_mode) -def DoInsertStmt(proc_cursor, gap_cursor, new_stmt, check_mode: CheckMode): - def static_check(): - assert False, "check must be done with chexo" - def dynamic_check(): - scope = gap_cursor.parent() - if scope.depth() == 0: - scope = scope._child_block("body") - p, fwd = gap_cursor._insert([new_stmt]) - fuzz(scope, fwd) - return p, fwd, None +def DoInsertStmt(gap_cursor, new_stmt, check_mode: CheckMode): + def check(checker: Checker): + if checker == "static": + assert False, "check must be done with chexo" + else: + scope = gap_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + else: + scope = scope.as_block() + p, fwd = gap_cursor._insert([new_stmt]) + fuzz(scope, fwd(scope)) + return p, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoDeletePass(proc): @@ -4178,85 +4048,68 @@ def map_s(self, sc): def DoEliminateIfDeadBranch(if_cursor, check_mode: CheckMode): - def static_check(): + def check(checker: Checker): if_stmt = if_cursor._node + original_ir = if_cursor.get_root() assert isinstance(if_stmt, LoopIR.If) - ir, fwd = if_cursor.get_root(), lambda x: x - try: - cond_node = LoopIR.Const(True, T.bool, if_stmt.srcinfo) - Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node) - cond = True + body = if_cursor.body() + ir, fwd = body._move(if_cursor.after()) + ir, fwd_del = fwd(if_cursor)._delete() + fwd = _compose(fwd_del, fwd) + if checker == "static": + cond_node = LoopIR.Const(True, T.bool, if_stmt.srcinfo) + Check_ExprEqvInContext(original_ir, if_stmt.cond, [if_stmt], cond_node) + else: + fuzz(if_cursor.as_block(), fwd(if_cursor.body())) except SchedulingError: try: - cond_node = LoopIR.Const(False, T.bool, if_stmt.srcinfo) - Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node) - cond = False + if checker == "static": + cond_node = LoopIR.Const(False, T.bool, if_stmt.srcinfo) + Check_ExprEqvInContext( + original_ir, if_stmt.cond, [if_stmt], cond_node + ) + else: + body = if_cursor.orelse() + ir, fwd = body._move(if_cursor.after()) + ir, fwd_del = fwd(if_cursor)._delete() + fwd = _compose(fwd_del, fwd) + fuzz(if_cursor.as_block(), fwd(if_cursor.orelse())) except SchedulingError: raise SchedulingError("If condition isn't always True or always False") - body = if_cursor.body() if cond else if_cursor.orelse() - ir, fwd = body._move(if_cursor.after()) - ir, fwd_del = fwd(if_cursor)._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd - def dynamic_check(): - if_stmt = if_cursor._node - - assert isinstance(if_stmt, LoopIR.If) - - body = if_cursor.body() - ir, fwd = body._move(if_cursor.after()) - ir, fwd_del = fwd(if_cursor)._delete() - fwd = _compose(fwd_del, fwd) - - try: - fuzz(if_cursor.parent(), fwd) - return ir, fwd - except SchedulingError: - body = if_cursor.orelse() - ir, fwd = body._move(if_cursor.after()) - ir, fwd_del = fwd(if_cursor)._delete() - fwd = _compose(fwd_del, fwd) - fuzz(if_cursor.parent(), fwd) - return ir, fwd - - return do_check(static_check, dynamic_check, "static") + return do_check(check, "static") # eliminate dead branch broken on chexo def 
DoEliminateDeadLoop(loop_cursor, check_mode: CheckMode): - def static_check(): + def check(checker: Checker): loop_stmt = loop_cursor._node assert isinstance(loop_stmt, LoopIR.For) - ir, fwd = loop_cursor.get_root(), lambda x: x - - try: - Check_CompareExprs(ir, [loop_stmt], loop_stmt.lo, ">=", loop_stmt.hi) - except SchedulingError: - raise SchedulingError("Loop condition isn't always False") - - ir, fwd_del = loop_cursor._delete() - - return ir, fwd_del - - def dynamic_check(): - loop_stmt = loop_cursor._node - - assert isinstance(loop_stmt, LoopIR.For) + if checker == "static": + ir = loop_cursor.get_root() + try: + Check_CompareExprs(ir, [loop_stmt], loop_stmt.lo, ">=", loop_stmt.hi) + except SchedulingError: + raise SchedulingError("Loop condition isn't always False") ir, fwd_del = loop_cursor._delete() - - fuzz(loop_cursor.parent(), fwd_del) + if checker == "dynamic": + scope = loop_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + else: + scope = scope.as_block() + fuzz(scope, fwd_del(scope)) return ir, fwd_del - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def DoEliminateDeadCode(stmt_cursor, check_mode: CheckMode): @@ -4281,34 +4134,16 @@ def DoDeleteBuffer(buf_cursor): def DoReuseBuffer(buf_cursor, rep_cursor, check_mode): - assert isinstance(buf_cursor._node, LoopIR.Alloc) - assert isinstance(rep_cursor._node, LoopIR.Alloc) - assert buf_cursor._node.type == rep_cursor._node.type - - buf_name = buf_cursor._node.name - buf_dims = len(buf_cursor._node.type.shape()) - rep_name = rep_cursor._node.name - first_assn = True - - def static_check(): - ir, fwd = rep_cursor._delete() - - def mk_read(c): - return {"name": buf_name} - - def mk_write(c): - nonlocal first_assn - if first_assn: - first_assn = False - Check_IsDeadAfter(buf_cursor.get_root(), [c._node], buf_name, buf_dims) - return {"name": buf_name} + def check(checker: Checker): + assert isinstance(buf_cursor._node, LoopIR.Alloc) + assert isinstance(rep_cursor._node, LoopIR.Alloc) + assert buf_cursor._node.type == rep_cursor._node.type - for c in get_rest_of_block(rep_cursor): - ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write) - return ir, fwd + buf_name = buf_cursor._node.name + buf_dims = len(buf_cursor._node.type.shape()) + rep_name = rep_cursor._node.name + first_assn = True - def dynamic_check(): ir, fwd = rep_cursor._delete() def mk_read(c): @@ -4318,15 +4153,20 @@ def mk_write(c): nonlocal first_assn if first_assn: first_assn = False + if checker == "static": + Check_IsDeadAfter( + buf_cursor.get_root(), [c._node], buf_name, buf_dims + ) return {"name": buf_name} for c in get_rest_of_block(rep_cursor): ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read) ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write) - fuzz(get_rest_of_block(rep_cursor), fwd) + if checker == "dynamic": + fuzz(get_rest_of_block(rep_cursor), fwd(get_rest_of_block(rep_cursor))) return ir, fwd - return do_check(static_check, dynamic_check, check_mode) + return do_check(check, check_mode) def index_range_analysis_wrapper(expr: LoopIR.expr) -> IndexRange: @@ -4463,11 +4303,12 @@ def do_e(self, e): def DoFoldBuffer(alloc_cursor, dim_idx, new_size, check_mode: CheckMode): - alloc_name = alloc_cursor._node.name + def check(checker: Checker): + alloc_name = alloc_cursor._node.name - def static_check(): - buffer_check = CheckFoldBuffer(alloc_name, dim_idx, new_size) - buffer_check.do_stmts([c._node for c in 
get_rest_of_block(alloc_cursor)]) + if checker == "static": + buffer_check = CheckFoldBuffer(alloc_name, dim_idx, new_size) + buffer_check.do_stmts([c._node for c in get_rest_of_block(alloc_cursor)]) size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) ir, fwd = ( @@ -4517,337 +4358,345 @@ def mk_write(c): new_alloc_cursor = fwd(alloc_cursor) after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] - Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + if checker == "static": + Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + else: + fuzz(get_rest_of_block(alloc_cursor), fwd(get_rest_of_block(alloc_cursor))) return ir, fwd - def dynamic_check(): - size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) - ir, fwd = ( - alloc_cursor._child_node("type") - ._child_block("hi")[dim_idx] - ._replace([size_expr]) - ) - - def make_index_mod(e): - return LoopIR.BinOp("%", e, size_expr, T.index, e.srcinfo) - - def mk_read(c): - rd = c._node - new_idx = rd.idx.copy() - if isinstance(rd, LoopIR.Read): - new_idx[dim_idx] = make_index_mod(rd.idx[dim_idx]) - return {"idx": new_idx} - - elif isinstance(rd, LoopIR.WindowExpr): - if isinstance(rd.idx[dim_idx], LoopIR.Point): - new_idx[dim_idx] = LoopIR.Point( - make_index_mod(rd.idx[dim_idx].pt), rd.srcinfo - ) - else: - # TODO: see if check_bounds catches the case where lo, hi spans a multiple - # of size, which would break the buffer folding - new_idx[dim_idx] = LoopIR.Interval( - make_index_mod(rd.idx[dim_idx].lo), - make_index_mod(rd.idx[dim_idx].hi), - rd.srcinfo, - ) + return do_check(check, check_mode) - return {"idx": new_idx} - else: - raise NotImplementedError(f"Did not implement {type(rd)}.") - def mk_write(c): - s = c._node - new_idx = s.idx.copy() - new_idx[dim_idx] = make_index_mod(s.idx[dim_idx]) - return {"idx": new_idx} +def DoStageMem( + block_cursor, buf_name_str, w_exprs, new_name, check_mode, use_accum_zero=False +): + new_name = Sym(new_name) - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + def check(checker: Checker): + def get_typ_mem(): + syms_env = extract_env(block_cursor[0]) + for name, typ, mem in syms_env: + if str(name) == buf_name_str: + return name, typ, mem + assert False, "Must find the symbol in env" - new_alloc_cursor = fwd(alloc_cursor) - after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] + buf_name, buf_typ, mem = get_typ_mem() + buf_typ = buf_typ if not isinstance(buf_typ, T.Window) else buf_typ.as_tensor - fuzz(get_rest_of_block(alloc_cursor), fwd) - return ir, fwd + if len(w_exprs) != len(buf_typ.shape()): + raise SchedulingError( + f"expected windowing of '{buf_name}' " + f"to have {len(buf_typ.shape())} indices, " + f"but only got {len(w_exprs)}" + ) - return do_check(static_check, dynamic_check, check_mode) + shape = [ + LoopIR.BinOp("-", w[1], w[0], T.index, w[0].srcinfo) + for w in w_exprs + if isinstance(w, tuple) + ] + if all(isinstance(w, LoopIR.expr) for w in w_exprs): + new_typ = buf_typ.basetype() + else: + new_typ = T.Tensor(shape, False, buf_typ.basetype()) + + def rewrite_idx(idx): + assert len(idx) == len(w_exprs) + return [ + LoopIR.BinOp("-", i, w[0], T.index, i.srcinfo) + for i, w in zip(idx, w_exprs) + if isinstance(w, tuple) + ] + def rewrite_win(w_idx): + assert len(w_idx) == len(w_exprs) -def DoStageMem(block_cursor, buf_name, w_exprs, new_name, use_accum_zero=False): - new_name = Sym(new_name) + def 
off_w(w, off):
+        if isinstance(w, LoopIR.Interval):
+            lo = LoopIR.BinOp("-", w.lo, off, T.index, w.srcinfo)
+            hi = LoopIR.BinOp("-", w.hi, off, T.index, w.srcinfo)
+            return LoopIR.Interval(lo, hi, w.srcinfo)
+        else:
+            assert isinstance(w, LoopIR.Point)
+            pt = LoopIR.BinOp("-", w.pt, off, T.index, w.srcinfo)
+            return LoopIR.Point(pt, w.srcinfo)
-    def get_typ_mem():
-        syms_env = extract_env(block_cursor[0])
-        for name, typ, mem in syms_env:
-            if str(name) == buf_name:
-                return name, typ, mem
-        assert False, "Must find the symbol in env"
+        w_los = [w_e[0] if isinstance(w_e, tuple) else w_e for w_e in w_exprs]
-    buf_name, buf_typ, mem = get_typ_mem()
-    buf_typ = buf_typ if not isinstance(buf_typ, T.Window) else buf_typ.as_tensor
+        return [off_w(w_i, w_e) for w_i, w_e in zip(w_idx, w_los)]
-    if len(w_exprs) != len(buf_typ.shape()):
-        raise SchedulingError(
-            f"expected windowing of '{buf_name}' "
-            f"to have {len(buf_typ.shape())} indices, "
-            f"but only got {len(w_exprs)}"
-        )
+    ir = block_cursor.get_root()
+    block = [s._node for s in block_cursor]
+    if use_accum_zero:
+        n_dims = len(buf_typ.shape())
+        if checker == "static":
+            Check_BufferReduceOnly(
+                ir,
+                block,
+                buf_name,
+                n_dims,
+            )
-    shape = [
-        LoopIR.BinOp("-", w[1], w[0], T.index, w[0].srcinfo)
-        for w in w_exprs
-        if isinstance(w, tuple)
-    ]
-    if all(isinstance(w, LoopIR.expr) for w in w_exprs):
-        new_typ = buf_typ.basetype()
-    else:
-        new_typ = T.Tensor(shape, False, buf_typ.basetype())
+    n_dims = len(buf_typ.shape())
+    basetyp = new_typ.basetype() if isinstance(new_typ, T.Tensor) else new_typ
+    srcinfo = block[0].srcinfo
-    def rewrite_idx(idx):
-        assert len(idx) == len(w_exprs)
-        return [
-            LoopIR.BinOp("-", i, w[0], T.index, i.srcinfo)
-            for i, w in zip(idx, w_exprs)
-            if isinstance(w, tuple)
-        ]
+    new_alloc = [LoopIR.Alloc(new_name, new_typ, mem, srcinfo)]
+    ir, fwd = block_cursor[0].before()._insert(new_alloc)
-    def rewrite_win(w_idx):
-        assert len(w_idx) == len(w_exprs)
+    def get_inner_stmt(loop_nest_c):
+        node = loop_nest_c._node
+        if not isinstance(node, LoopIR.For):
+            return loop_nest_c
+        return get_inner_stmt(loop_nest_c.body()[0])
-        def off_w(w, off):
-            if isinstance(w, LoopIR.Interval):
-                lo = LoopIR.BinOp("-", w.lo, off, T.index, w.srcinfo)
-                hi = LoopIR.BinOp("-", w.hi, off, T.index, w.srcinfo)
-                return LoopIR.Interval(lo, hi, w.srcinfo)
+    if checker == "dynamic":
+        coverage_summary = fuzz_single_scope(
+            block_cursor,
+            StageMemArgs(buf_name, StagedWindowExpr(w_exprs), block_cursor),
+        )
+    else:
+        coverage_summary = None
+
+    # Insert guards to ensure load/store stages don't access out of bounds
+    def insert_safety_guards(ir, fwd, ctxt_stmt_c, access, buf_typ):
+        def check_cond(cond):
+            ctxt_stmt = ctxt_stmt_c._node
+            true_node = LoopIR.Const(True, T.bool, ctxt_stmt.srcinfo)
+            try:
+                Check_ExprEqvInContext(ir, cond, [ctxt_stmt], true_node)
+                return True
+            except SchedulingError:
+                return False
+
+        # Get a list of lower/upper bound on the index accesses
+        const_0 = LoopIR.Const(0, T.int, access.srcinfo)
+        conds = []
+        if coverage_summary is None:
+            for i in zip(access.idx, buf_typ.shape()):
+                lower_bound_cond = LoopIR.BinOp(
+                    "<=", const_0, i[0], T.bool, access.srcinfo
+                )
+                if not check_cond(lower_bound_cond):
+                    conds.append(lower_bound_cond)
+                upper_bound_cond = LoopIR.BinOp(
+                    "<", i[0], i[1], T.bool, access.srcinfo
+                )
+                if not check_cond(upper_bound_cond):
+                    conds.append(upper_bound_cond)
+        else:
-            assert isinstance(w, LoopIR.Point)
-            pt = LoopIR.BinOp("-", w.pt, off, T.index, w.srcinfo)
-            return LoopIR.Point(pt, w.srcinfo)
-
-        w_los = [w_e[0] if isinstance(w_e, tuple) else w_e for w_e in w_exprs]
-
-        return [off_w(w_i, w_e) for w_i, w_e in zip(w_idx, w_los)]
-
-    ir = block_cursor.get_root()
-    block = [s._node for s in block_cursor]
-    if use_accum_zero:
-        n_dims = len(buf_typ.shape())
-        Check_BufferReduceOnly(
-            ir,
-            block,
-            buf_name,
-            n_dims,
-        )
+            for (
+                access_idx,
+                buf_upper_bound,
+                guard_upper_bound,
+                guard_lower_bound,
+            ) in zip(
+                access.idx,
+                buf_typ.shape(),
+                coverage_summary.stage_mem_result.dim_needs_upper_bound_guard,
+                coverage_summary.stage_mem_result.dim_needs_lower_bound_guard,
+            ):
+                lower_bound_cond = LoopIR.BinOp(
+                    "<=", const_0, access_idx, T.bool, access.srcinfo
+                )
+                if guard_lower_bound:
+                    conds.append(lower_bound_cond)
+                upper_bound_cond = LoopIR.BinOp(
+                    "<", access_idx, buf_upper_bound, T.bool, access.srcinfo
+                )
+                if guard_upper_bound:
+                    conds.append(upper_bound_cond)
-    n_dims = len(buf_typ.shape())
-    basetyp = new_typ.basetype() if isinstance(new_typ, T.Tensor) else new_typ
-    srcinfo = block[0].srcinfo
+        if len(conds) == 0:
+            return ir, fwd
-    new_alloc = [LoopIR.Alloc(new_name, new_typ, mem, srcinfo)]
-    ir, fwd = block_cursor[0].before()._insert(new_alloc)
+        # Construct the condition
+        cond = conds[0]
+        for c in conds[1:]:
+            cond = LoopIR.BinOp("and", cond, c, T.bool, cond.srcinfo)
-    def get_inner_stmt(loop_nest_c):
-        node = loop_nest_c._node
-        if not isinstance(node, LoopIR.For):
-            return loop_nest_c
-        return get_inner_stmt(loop_nest_c.body()[0])
+        # Construct the If statement and wrap the context statement
+        def guard_wrapper(body):
+            return LoopIR.If(cond, body, [], srcinfo)
-    # Insert guards to ensure load/store stages don't access out of bounds
-    def insert_safety_guards(ir, fwd, ctxt_stmt_c, access, buf_typ):
-        def check_cond(cond):
-            ctxt_stmt = ctxt_stmt_c._node
-            true_node = LoopIR.Const(True, T.bool, ctxt_stmt.srcinfo)
-            try:
-                Check_ExprEqvInContext(ir, cond, [ctxt_stmt], true_node)
-                return True
-            except SchedulingError:
-                return False
+        # You want to forward `ctxt_stmt_c` instead of relying on passing
+        # the forwarded version. However, in all the current callees, the
+        # statement would have been just constructed and if you try to forward
+        # you get an error.
+        ir, fwd_wrap = ctxt_stmt_c.as_block()._wrap(guard_wrapper, "body")
+        fwd = _compose(fwd_wrap, fwd)
-        # Get a list of lower/upper bound on the index accesses
-        const_0 = LoopIR.Const(0, T.int, access.srcinfo)
-        conds = []
-        for i in zip(access.idx, buf_typ.shape()):
-            lower_bound_cond = LoopIR.BinOp("<=", const_0, i[0], T.bool, access.srcinfo)
-            if not check_cond(lower_bound_cond):
-                conds.append(lower_bound_cond)
-            upper_bound_cond = LoopIR.BinOp("<", i[0], i[1], T.bool, access.srcinfo)
-            if not check_cond(upper_bound_cond):
-                conds.append(upper_bound_cond)
-
-        if len(conds) == 0:
-            return ir, fwd
+        return ir, fwd
-        # Construct the condition
-        cond = conds[0]
-        for c in conds[1:]:
-            cond = LoopIR.BinOp("and", cond, c, T.bool, cond.srcinfo)
+    def idx_contained_by_window(idx, block_cursor):
+        """
+        Returns True if idx always lies in staged window range.
+        Returns False if idx never lies in staged window range.
+        Otherwise, will raise a SchedulingError.
+        """
+        if coverage_summary is None:
+            p = idx.get_root()
+            return Check_Access_In_Window(p, idx, w_exprs, block_cursor)
+        else:
+            return (
+                idx.get_path()
+                in coverage_summary.stage_mem_result.overlapping_accesses
+            )
-        # Construct the If statement and wrap the context statement
-        def guard_wrapper(body):
-            return LoopIR.If(cond, body, [], srcinfo)
+    actualR = actualW = False
+    WShadow = False
+    # Conservatively, shadowing logic only works for single element staging windows.
+    w_is_pt = all(not isinstance(w, tuple) for w in w_exprs)
-        # You want to forward `ctxt_stmt_c` instead of relying on passing
-        # the forwarded version. However, in all the current callees, the
-        # statement would have been just constructed and if you try to forward
-        # you get an error.
-        ir, fwd_wrap = ctxt_stmt_c.as_block()._wrap(guard_wrapper, "body")
-        fwd = _compose(fwd_wrap, fwd)
+    def mk_read(c, block_cursor):
+        nonlocal actualR
+        rd = c._node
-        return ir, fwd
+        if isinstance(rd, LoopIR.Read):
+            if idx_contained_by_window(c, block_cursor):
+                _idx = rewrite_idx(rd.idx)
+                actualR = True
+                return {"name": new_name, "idx": _idx}
+        elif isinstance(rd, LoopIR.WindowExpr):
+            if any(
+                isinstance(w, LoopIR.Interval) and not isinstance(w_e, tuple)
+                for w, w_e in zip(rd.idx, w_exprs)
+            ):
+                raise SchedulingError(
+                    f"Existing WindowExpr {rd} has a windowed dimension which is not windowed in the new staged window."
+                )
-    def idx_contained_by_window(idx, block_cursor):
-        """
-        Returns True if idx always lies in staged window range.
-        Returns False if idx never lies in staged window range.
-        Otherwise, will raise a SchedulingError.
-        """
-        p = idx.get_root()
-        return Check_Access_In_Window(p, idx, w_exprs, block_cursor)
+            if idx_contained_by_window(c, block_cursor):
+                _idx = rewrite_win(rd.idx)
+                _typ = T.Window(new_typ, rd.type.as_tensor, new_name, _idx)
+                actualR = True
+                return {"name": new_name, "idx": _idx, "type": _typ}
-    actualR = actualW = False
-    WShadow = False
-    # Conservatively, shadowing logic only works for single element staging windows.
-    w_is_pt = all(not isinstance(w, tuple) for w in w_exprs)
+    def mk_write(c, block_cursor):
+        nonlocal actualR
+        nonlocal actualW
+        nonlocal WShadow
+        s = c._node
+        if isinstance(s, (LoopIR.Assign, LoopIR.Reduce)):
+            if idx_contained_by_window(c, block_cursor):
+                actualW = True
+                if isinstance(s, LoopIR.Reduce):
+                    actualR = True
+                if not actualR and w_is_pt:
+                    WShadow = True
+                return {"name": new_name, "idx": rewrite_idx(s.idx)}
+
+    for c in block_cursor:
+        ir, fwd = _replace_reads(
+            ir, fwd, c, buf_name, partial(mk_read, block_cursor=fwd(block_cursor))
+        )
-    def mk_read(c, block_cursor):
-        nonlocal actualR
-        rd = c._node
+        ir, fwd = _replace_writes(
+            ir, fwd, c, buf_name, partial(mk_write, block_cursor=fwd(block_cursor))
+        )
-        if isinstance(rd, LoopIR.Read):
-            if idx_contained_by_window(c, block_cursor):
-                _idx = rewrite_idx(rd.idx)
-                actualR = True
-                return {"name": new_name, "idx": _idx}
-        elif isinstance(rd, LoopIR.WindowExpr):
-            if any(
-                isinstance(w, LoopIR.Interval) and not isinstance(w_e, tuple)
-                for w, w_e in zip(rd.idx, w_exprs)
-            ):
-                raise SchedulingError(
-                    f"Existing WindowExpr {rd} has a widnowed dimension which is not windowed in the new staged window."
-                )
+    if actualR and not WShadow:
+        load_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)]
+        load_widx = [LoopIR.Read(s, [], T.index, srcinfo) for s in load_iter]
+        if use_accum_zero:
+            load_rhs = LoopIR.Const(0.0, basetyp, srcinfo)
+        else:
+            cp_load_widx = load_widx.copy()
+            load_ridx = []
+            for w in w_exprs:
+                if isinstance(w, tuple):
+                    load_ridx.append(
+                        LoopIR.BinOp(
+                            "+", cp_load_widx.pop(0), w[0], T.index, srcinfo
+                        )
+                    )
+                else:
+                    load_ridx.append(w)
+            load_rhs = LoopIR.Read(buf_name, load_ridx, basetyp, srcinfo)
+
+        load_nest = [LoopIR.Assign(new_name, basetyp, load_widx, load_rhs, srcinfo)]
+
+        for i, n in reversed(list(zip(load_iter, shape))):
+            loop = LoopIR.For(
+                i,
+                LoopIR.Const(0, T.index, srcinfo),
+                n,
+                load_nest,
+                LoopIR.Seq(),
+                srcinfo,
+            )
+            load_nest = [loop]
-            if idx_contained_by_window(c, block_cursor):
-                _idx = rewrite_win(rd.idx)
-                _typ = T.Window(new_typ, rd.type.as_tensor, new_name, _idx)
-                actualR = True
-                return {"name": new_name, "idx": _idx, "type": _typ}
-
-    def mk_write(c, block_cursor):
-        nonlocal actualR
-        nonlocal actualW
-        nonlocal WShadow
-        s = c._node
-        if isinstance(s, (LoopIR.Assign, LoopIR.Reduce)):
-            if idx_contained_by_window(c, block_cursor):
-                actualW = True
-                if isinstance(s, LoopIR.Reduce):
-                    actualR = True
-                if not actualR and w_is_pt:
-                    WShadow = True
-                return {"name": new_name, "idx": rewrite_idx(s.idx)}
-
-    for c in block_cursor:
-        ir, fwd = _replace_reads(
-            ir, fwd, c, buf_name, partial(mk_read, block_cursor=fwd(block_cursor))
-        )
+        ir, fwd_ins = fwd(block_cursor[0]).before()._insert(load_nest)
+        fwd = _compose(fwd_ins, fwd)
-        ir, fwd = _replace_writes(
-            ir, fwd, c, buf_name, partial(mk_write, block_cursor=fwd(block_cursor))
-        )
+        if not use_accum_zero:
+            load_nest_c = fwd(block_cursor[0]).prev()
+            ir, fwd = insert_safety_guards(
+                ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ
+            )
-    if actualR and not WShadow:
-        load_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)]
-        load_widx = [LoopIR.Read(s, [], T.index, srcinfo) for s in load_iter]
-        if use_accum_zero:
-            load_rhs = LoopIR.Const(0.0, basetyp, srcinfo)
-        else:
-            cp_load_widx = load_widx.copy()
-            load_ridx = []
+    if actualW:
+        store_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)]
+        store_ridx = [LoopIR.Read(s, [], T.index, srcinfo) for s in store_iter]
+        cp_store_ridx = store_ridx.copy()
+        store_widx = []
+        for w in w_exprs:
+            if isinstance(w, tuple):
-                load_ridx.append(
-                    LoopIR.BinOp("+", cp_load_widx.pop(0), w[0], T.index, srcinfo)
+                store_widx.append(
+                    LoopIR.BinOp("+", cp_store_ridx.pop(0), w[0], T.index, srcinfo)
                 )
-            else:
-                load_ridx.append(w)
-        load_rhs = LoopIR.Read(buf_name, load_ridx, basetyp, srcinfo)
-
-        load_nest = [LoopIR.Assign(new_name, basetyp, load_widx, load_rhs, srcinfo)]
-
-        for i, n in reversed(list(zip(load_iter, shape))):
-            loop = LoopIR.For(
-                i,
-                LoopIR.Const(0, T.index, srcinfo),
-                n,
-                load_nest,
-                LoopIR.Seq(),
-                srcinfo,
-            )
-            load_nest = [loop]
+            else:
+                store_widx.append(w)
+
+        store_rhs = LoopIR.Read(new_name, store_ridx, basetyp, srcinfo)
+        store_stmt = LoopIR.Reduce if use_accum_zero else LoopIR.Assign
+        store_nest = [store_stmt(buf_name, basetyp, store_widx, store_rhs, srcinfo)]
+
+        for i, n in reversed(list(zip(store_iter, shape))):
+            loop = LoopIR.For(
+                i,
+                LoopIR.Const(0, T.index, srcinfo),
+                n,
+                store_nest,
+                LoopIR.Seq(),
+                srcinfo,
+            )
+            store_nest = [loop]
-        ir, fwd_ins = fwd(block_cursor[0]).before()._insert(load_nest)
-        fwd = _compose(fwd_ins, fwd)
+        ir, fwd_ins = fwd(block_cursor[-1]).after()._insert(store_nest)
+        fwd = _compose(fwd_ins, fwd)
-    if not use_accum_zero:
-        load_nest_c = fwd(block_cursor[0]).prev()
+        store_nest_c = fwd(block_cursor[-1]).next()
+        store_stmt_c = get_inner_stmt(store_nest_c)
         ir, fwd = insert_safety_guards(
-            ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ
+            ir, fwd, store_stmt_c, store_stmt_c._node, buf_typ
         )
-    if actualW:
-        store_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)]
-        store_ridx = [LoopIR.Read(s, [], T.index, srcinfo) for s in store_iter]
-        cp_store_ridx = store_ridx.copy()
-        store_widx = []
-        for w in w_exprs:
-            if isinstance(w, tuple):
-                store_widx.append(
-                    LoopIR.BinOp("+", cp_store_ridx.pop(0), w[0], T.index, srcinfo)
-                )
-            else:
-                store_widx.append(w)
-
-        store_rhs = LoopIR.Read(new_name, store_ridx, basetyp, srcinfo)
-        store_stmt = LoopIR.Reduce if use_accum_zero else LoopIR.Assign
-        store_nest = [store_stmt(buf_name, basetyp, store_widx, store_rhs, srcinfo)]
-
-        for i, n in reversed(list(zip(store_iter, shape))):
-            loop = LoopIR.For(
-                i,
-                LoopIR.Const(0, T.index, srcinfo),
-                n,
-                store_nest,
-                LoopIR.Seq(),
-                srcinfo,
+    # new alloc, load_nest + new_body + store_nest
+    new_block_c = fwd(block_cursor[0]).as_block().expand(0, len(block_cursor) - 1)
+    if actualR and not WShadow:
+        new_block_c = new_block_c.expand(1, 0)
+    if actualW:
+        new_block_c = new_block_c.expand(0, 1)
+    if not actualR and not actualW:
+        raise SchedulingError(
+            f"Cannot stage '{buf_name}' with the given window shape. Wrong window shape, or '{buf_name}' not accessed in the given scope?"
         )
-            )
-        store_nest = [loop]
-
-    ir, fwd_ins = fwd(block_cursor[-1]).after()._insert(store_nest)
-    fwd = _compose(fwd_ins, fwd)
-    store_nest_c = fwd(block_cursor[-1]).next()
-    store_stmt_c = get_inner_stmt(store_nest_c)
-    ir, fwd = insert_safety_guards(
-        ir, fwd, store_stmt_c, store_stmt_c._node, buf_typ
-    )
-
-    # new alloc, load_nest + new_body + store_nest
-    new_block_c = fwd(block_cursor[0]).as_block().expand(0, len(block_cursor) - 1)
-    if actualR and not WShadow:
-        new_block_c = new_block_c.expand(1, 0)
-    if actualW:
-        new_block_c = new_block_c.expand(0, 1)
-    if not actualR and not actualW:
-        raise SchedulingError(
-            f"Cannot stage '{buf_name}' with the given window shape. Wrong window shape, or '{buf_name}' not accessed in the given scope?"
-        )
+    if checker == "static":
+        Check_Bounds(ir, new_alloc[0], [c._node for c in new_block_c])
+    else:
+        scope = block_cursor.parent()
+        if scope.depth() == 0:
+            scope = scope._child_block("body")
+        else:
+            scope = scope.as_block()
+        fuzz(scope, fwd(scope))
-    Check_Bounds(ir, new_alloc[0], [c._node for c in new_block_c])
+    return ir, fwd
-    return ir, fwd
+    return do_check(check, check_mode)
 
 
 def DoUnrollBuffer(alloc_cursor, dim):
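For intuition, the guard construction in the hunk above boils down to collecting at most one lower and one upper bound condition per staged dimension and folding them into a single conjunction for the wrapping If. The sketch below uses plain Python strings in place of LoopIR expressions; fold_guard_conds and its argument names are illustrative stand-ins and are not part of the patch.

    # Illustrative sketch only: plain strings stand in for LoopIR expressions.
    def fold_guard_conds(idx_names, dim_sizes, needs_lo, needs_hi):
        """Collect per-dimension bound checks and fold them into one condition."""
        conds = []
        for i, n, lo, hi in zip(idx_names, dim_sizes, needs_lo, needs_hi):
            if lo:
                conds.append(f"0 <= {i}")   # lower-bound guard for this dimension
            if hi:
                conds.append(f"{i} < {n}")  # upper-bound guard for this dimension
        if not conds:
            return None                      # no guard needed; leave the stage as-is
        cond = conds[0]
        for c in conds[1:]:
            cond = f"({cond}) and ({c})"     # same left-to-right folding as the hunk
        return cond

    # fold_guard_conds(["i0", "i1"], ["n", "m"], [False, True], [True, True])
    # -> "((i0 < n) and (0 <= i1)) and (i1 < m)"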
diff --git a/src/exo/rewrite/chexo/LoopIR_transpiler.py b/src/exo/rewrite/chexo/LoopIR_transpiler.py
index 999b59a5c..a70a6704c 100644
--- a/src/exo/rewrite/chexo/LoopIR_transpiler.py
+++ b/src/exo/rewrite/chexo/LoopIR_transpiler.py
@@ -317,36 +317,43 @@ def enter_stmt(self, stmt_node: Node):
                 ):
                     if isinstance(symbolic_parent_dim, SymbolicSlice):
                         assert isinstance(js_parent_dim, Slice)
-                    out_of_bounds_sym = Sym("oob")
-                    js_out_of_bounds_cond = (
-                        "&&".join(
-                            (
-                                f"({js_staged_dim.index}<{js_parent_dim.upper_bound})",
-                                f"({js_parent_dim.lower_bound}<={js_staged_dim.index})",
-                            )
-                        )
+                        upper_bound_violation_sym = Sym("ub")
+                        lower_bound_violation_sym = Sym("lb")
+                        js_upper_bound_violation_cond = (
+                            f"({js_staged_dim.index}>={js_parent_dim.upper_bound})"
                             if isinstance(js_staged_dim, Point)
-                        else "&&".join(
-                            (
-                                f"({js_parent_dim.lower_bound}<={js_staged_dim.lower_bound})",
-                                f"({js_staged_dim.upper_bound})<{js_parent_dim.upper_bound})",
-                            )
-                        )
+                            else f"({js_staged_dim.upper_bound})>{js_parent_dim.upper_bound})"
+                        )
+                        js_lower_bound_violation_cond = (
+                            f"({js_parent_dim.lower_bound}<={js_staged_dim.index})"
+                            if isinstance(js_staged_dim, Point)
+                            else f"({js_parent_dim.lower_bound}<={js_staged_dim.lower_bound})"
                         )
                         self.bound_checks.append(
                             StagingBoundCheck(
-                                out_of_bounds_sym,
+                                upper_bound_violation_sym,
+                                lower_bound_violation_sym,
                                 symbolic_staged_dim,
                                 symbolic_parent_dim,
                                 self.parent_state.current_node,
                                 (
                                     IndexedFiller(
                                         self.parent_state.cov_placeholder,
-                                        f"let {repr(out_of_bounds_sym)}=false;",
+                                        "".join(
+                                            (
+                                                f"let {repr(upper_bound_violation_sym)}=false;",
+                                                f"let {repr(lower_bound_violation_sym)}=false;",
+                                            )
+                                        ),
                                     ),
                                     IndexedFiller(
                                         stage_placeholder,
-                                        f"if({js_out_of_bounds_cond}){{let {repr(out_of_bounds_sym)}=true;}}",
+                                        "".join(
+                                            (
+                                                f"if({js_upper_bound_violation_cond}){{let {repr(upper_bound_violation_sym)}=true;}}",
+                                                f"if({js_lower_bound_violation_cond}){{let {repr(lower_bound_violation_sym)}=true;}}",
+                                            )
+                                        ),
                                     ),
                                 ),
                             )
@@ -517,10 +524,10 @@ def access_tensor(
                 "".join(
                     (
                         f"{repr(access_set_sym)}{'_cw' if is_write else '_cr'}.add({js_access});",
-                        f"if({repr(access_set_sym)}_pw.has({js_access})){{{repr(self.coverage_sym)}=true}}",
+                        f"if({repr(access_set_sym)}_pw.has({js_access})){{{repr(self.coverage_sym)}=true;return [1,{CONTEXT_OBJECT_NAME},{{}}];}}",
                         *(
                             (
-                                f"if({repr(access_set_sym)}_pr.has({js_access})){{{repr(self.coverage_sym)}=true}}",
+                                f"if({repr(access_set_sym)}_pr.has({js_access})){{{repr(self.coverage_sym)}=true;return [1,{CONTEXT_OBJECT_NAME},{{}}];}}",
                             )
                             if is_write
                             else ()
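The transpiler change above replaces the single out-of-bounds flag with separate upper and lower bound violation flags. The toy model below is a conceptual illustration only, independent of the generated JavaScript and of the exact condition strings in the hunk: tracking the two sides separately is what later lets stage_mem emit, per dimension, only the guards that fuzzing actually observed to be violated.

    # Conceptual model of the split flags; not the transpiled code.
    def check_staged_point(index, lower_bound, upper_bound):
        """Report which side of the parent window a staged point index falls outside."""
        violated_upper = index >= upper_bound  # would need an upper-bound guard
        violated_lower = index < lower_bound   # would need a lower-bound guard
        return violated_upper, violated_lower

    assert check_staged_point(3, 0, 4) == (False, False)
    assert check_staged_point(4, 0, 4) == (True, False)
    assert check_staged_point(-1, 0, 4) == (False, True)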
diff --git a/src/exo/rewrite/chexo/chexo.py b/src/exo/rewrite/chexo/chexo.py
index 54dc89c8e..843cb68ad 100644
--- a/src/exo/rewrite/chexo/chexo.py
+++ b/src/exo/rewrite/chexo/chexo.py
@@ -1,8 +1,6 @@
-from itertools import chain
-import time
-from typing import Callable, Literal, Optional, Union
+from typing import Literal, Optional, Union
 
-from ...core.internal_cursors import Cursor, Block, Node, NodePath
+from ...core.internal_cursors import Block, Node, NodePath
 from .LoopIR_transpiler import CoverageArgs, StageMemArgs, Transpiler
 from .coverage import CoverageSkeleton
@@ -14,16 +12,11 @@ from ...core.prelude import Sym, SrcInfo
 from ...core.memory import DRAM, Memory
 import numpy as np
-from ..new_eff import SchedulingError
+from ..analysis import SchedulingError
 from .constraint_solver import (
     TRUE_CONSTRAINT,
-    Constraint,
-    ConstraintClause,
     ConstraintMaker,
-    ConstraintTerm,
     DisjointConstraint,
-    Expression,
-    Solution,
 )
 from pythonmonkey import eval as js_eval
@@ -603,15 +596,26 @@ def get_test_spec(
 MAX_ITERS = 20
 
 
+@dataclass
+class StageMemResult:
+    overlapping_accesses: frozenset[NodePath]
+    dim_needs_upper_bound_guard: tuple[bool, ...]
+    dim_needs_lower_bound_guard: tuple[bool, ...]
+
+
+@dataclass
+class CoverageSummary:
+    stage_mem_result: StageMemResult
+
+
 def fuzz(
     scope1: Block,
     scope2: Block,
-    staging_args: Optional[StageMemArgs] = None,
 ):
     """
+    fuzz to determine whether the behaviors of two programs are equivalent
     scope1: smallest scope containing all changes made by scheduling op in original program
     scope2: scope corresponding to starting scope in transformed program
-    staging_args: arguments to stage_mem scheduling op
     """
     cur_scope1 = TestScope(scope1)
     cur_scope2 = TestScope(scope2)
@@ -632,7 +636,6 @@ def fuzz(
             cm,
             spec1.var_renaming,
             spec1.forward_to_test(scope1),
-            spec1.forward_staging_args(staging_args),
         ),
     )
     transpiled_test2 = Transpiler(
@@ -641,7 +644,6 @@ def fuzz(
             cm,
             spec2.var_renaming,
             spec2.forward_to_test(scope2),
-            spec2.forward_staging_args(staging_args),
         ),
     )
@@ -711,3 +713,106 @@ def fuzz(
         cur_scope1 = cur_scope1.broaden()
         cur_scope2 = cur_scope2.broaden()
     raise SchedulingError("tests failed at broadest scope")
+
+
+def fuzz_single_scope(
+    scope: Block, staging_args: Optional[StageMemArgs] = None
+) -> CoverageSummary:
+    """
+    fuzz a program to determine properties besides program equivalence,
+    e.g. for stage_mem, which needs to know whether accesses can overlap, or for parallel loop checks
+    scope: smallest scope containing properties that need to be fuzzed for
+    staging_args: arguments to stage_mem scheduling op
+    """
+    cur_scope = TestScope(scope)
+    cur_type_map = cur_scope.get_type_map()
+
+    while cur_scope is not None:
+        cm = ConstraintMaker(cur_type_map)
+
+        spec = cur_scope.get_test_spec(cm, cur_type_map)
+
+        transpiled_test = Transpiler(
+            # new proc that contains the current scope as a body, not the entire proc
+            spec.proc,
+            CoverageArgs(
+                cm,
+                spec.var_renaming,
+                spec.forward_to_test(scope),
+                spec.forward_staging_args(staging_args),
+            ),
+        )
+
+        config_fields = transpiled_test.get_configs()
+
+        arg_types = spec.arg_types
+        # precondition of current scope in both original and transformed program
+        constraint = spec.constraint
+
+        # symbolic representation of control flow in both original and transformed scope
+        coverage_skeleton = transpiled_test.get_coverage_skeleton()
+        assert coverage_skeleton is not None
+        tests_passed = True
+        skipped_tests = 0
+        iters = 0
+        while (
+            not coverage_skeleton.get_coverage_progress().is_finished()
+            and iters < MAX_ITERS
+            and tests_passed
+        ):
+            test_case = generate_test_case(
+                arg_types,
+                config_fields,
+                constraint,
+                coverage_skeleton,
+                cm,
+            )
+            # if constraint is unsolvable
+            if test_case is None:
+                skipped_tests += 1
+                if skipped_tests > MAX_SKIPPED_TESTS:
+                    # program should pass but not testing it is probably bad
+                    assert False
+                else:
+                    continue
+
+            out = run_test_case(test_case, transpiled_test)
+            if out == "failed":
+                # precondition in called subproc failed or out of bounds access
+                tests_passed = False
+                break
+            assert out.coverage_result is not None
+            coverage_skeleton.update_coverage(out.coverage_result)
+            iters += 1
+        if tests_passed:
+            overlapping_accesses = []
+            invalid_staging = False
+            for staging_overlap in coverage_skeleton.staging_overlaps:
+                if staging_overlap.has_overlap:
+                    if staging_overlap.has_disjoint_access:
+                        invalid_staging = True
+                        break
+                    else:
+                        overlapping_accesses.append(
+                            spec.backward_from_test(staging_overlap.access_cursor)
+                        )
+            if invalid_staging:
+                cur_scope = cur_scope.broaden()
+                continue
+            else:
+                return CoverageSummary(
+                    StageMemResult(
+                        frozenset(overlapping_accesses),
+                        tuple(
+                            staging_bound_check.violated_upper_bound
+                            for staging_bound_check in coverage_skeleton.staging_bound_checks
+                        ),
+                        tuple(
+                            staging_bound_check.violated_lower_bound
+                            for staging_bound_check in coverage_skeleton.staging_bound_checks
+                        ),
+                    )
+                )
+        else:
+            cur_scope = cur_scope.broaden()
+    raise SchedulingError("tests failed at broadest scope")
diff --git a/src/exo/rewrite/chexo/constraint_solver.py b/src/exo/rewrite/chexo/constraint_solver.py
index be53ac8ce..971cdf0d7 100644
--- a/src/exo/rewrite/chexo/constraint_solver.py
+++ b/src/exo/rewrite/chexo/constraint_solver.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Callable, Literal, Union, Optional
 
 from ...core.configs import Config
diff --git a/src/exo/rewrite/chexo/coverage.py b/src/exo/rewrite/chexo/coverage.py
index 5d03059f2..e6fd80e37 100644
--- a/src/exo/rewrite/chexo/coverage.py
+++ b/src/exo/rewrite/chexo/coverage.py
@@ -1,19 +1,18 @@
 from dataclasses import dataclass, field
-from itertools import groupby
-from typing import Generator, Iterable, Optional, Union
+from typing import Generator, Optional, Union
 
 import numpy as np
 
 from .constraint_solver import (
     Constraint,
     ConstraintMaker,
-    ConstraintTerm,
     DisjointConstraint,
     TRUE_CONSTRAINT,
     Expression,
     Solution,
 )
 from ...core.prelude import Sym
-from ...core.internal_cursors import Node, NodePath
+from ...core.internal_cursors import NodePath
+from ..analysis import SchedulingError
 
 
 @dataclass
@@ -406,67 +405,91 @@ def rename_syms(self, lookup: dict[Sym, Sym]) -> "SymbolicSlice":
 
 @dataclass
 class StagingBoundCheck:
-    out_of_bounds_sym: Sym
+    upper_bound_violation_sym: Sym
+    lower_bound_violation_sym: Sym
     staged_index: SymbolicWindowIndex
     parent_index: SymbolicSlice
     node: CoverageSkeletonNode
    indexed_fillers: tuple[IndexedFiller, ...]
-    had_out_of_bounds: bool = False
-    visited_out_of_bounds: bool = False
+    violated_upper_bound: bool = False
+    violated_lower_bound: bool = False
+    visited_upper_bound_violation: bool = False
+    visited_lower_bound_violation: bool = False
 
     def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]:
         for indexed_filler in self.indexed_fillers:
             yield indexed_filler
 
     def get_coverage_syms(self) -> frozenset[Sym]:
-        return frozenset((self.out_of_bounds_sym,))
+        return frozenset(
+            (self.upper_bound_violation_sym, self.lower_bound_violation_sym)
+        )
 
     def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]):
-        out_of_bounds = coverage_result[repr(self.out_of_bounds_sym)]
-        assert isinstance(out_of_bounds, bool)
-        self.had_out_of_bounds |= out_of_bounds
-        self.visited_out_of_bounds |= out_of_bounds
+        violated_upper_bound = coverage_result[repr(self.upper_bound_violation_sym)]
+        violated_lower_bound = coverage_result[repr(self.lower_bound_violation_sym)]
+        assert isinstance(violated_upper_bound, bool) and isinstance(
+            violated_lower_bound, bool
+        )
+        self.violated_upper_bound |= violated_upper_bound
+        self.violated_lower_bound |= violated_lower_bound
 
     def get_coverage_progress(self) -> CoverageProgress:
         return CoverageProgress(
-            (1 if self.visited_out_of_bounds else 0),
-            1,
+            (1 if self.visited_upper_bound_violation else 0)
+            + (1 if self.visited_lower_bound_violation else 0),
+            2,
         )
 
     def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState:
-        if not self.visited_out_of_bounds:
+        if not self.visited_upper_bound_violation:
             out_of_bounds_cond = (
                 Constraint(
                     self.staged_index.index.add(Expression.from_constant(1))
                     .negate()
                     .add(self.parent_index.upper_bound),
                     True,
-                )
-                .lift_to_disjoint_constraint()
-                .intersect(
-                    Constraint(
-                        self.staged_index.index.add(
-                            self.parent_index.lower_bound.negate()
-                        ),
-                        True,
-                    ).lift_to_disjoint_constraint()
-                )
+                ).lift_to_disjoint_constraint()
                 if isinstance(self.staged_index, SymbolicPoint)
                 else Constraint(
                     self.staged_index.upper_bound.negate().add(
                         self.parent_index.upper_bound
                     ),
                     True,
-                )
-                .lift_to_disjoint_constraint()
-                .intersect(
-                    Constraint(
-                        self.staged_index.lower_bound.add(
-                            self.parent_index.lower_bound.negate()
-                        ),
-                        True,
-                    ).lift_to_disjoint_constraint()
-                )
+                ).lift_to_disjoint_constraint()
+            )
+            path_constraint = self.node.get_complete_constraint().intersect(
+                out_of_bounds_cond
+            )
+            sym_renaming, _ = state.cm.rename_sym_set(
+                path_constraint.collect_syms(),
+                state.free_vars,
+            )
+            new_constraint = state.current_constraint.intersect(
+                path_constraint.rename_syms(sym_renaming)
+            )
+            new_solution = state.cm.solve_constraint(
+                new_constraint, bound=state.bound, search_limit=state.search_limit
+            )
+            if (
+                new_solution is None and state.is_base_constraint
+            ) or new_solution is not None:
+                self.visited_upper_bound_violation = True
+            if new_solution is not None:
+                return state.update_solution(new_constraint, new_solution)
+        if not self.visited_lower_bound_violation:
+            out_of_bounds_cond = (
+                Constraint(
+                    self.staged_index.index.add(self.parent_index.lower_bound.negate()),
+                    True,
+                ).lift_to_disjoint_constraint()
+                if isinstance(self.staged_index, SymbolicPoint)
+                else Constraint(
+                    self.staged_index.lower_bound.add(
+                        self.parent_index.lower_bound.negate()
+                    ),
+                    True,
+                ).lift_to_disjoint_constraint()
             )
             path_constraint = self.node.get_complete_constraint().intersect(
                 out_of_bounds_cond
@@ -484,7 +507,7 @@ def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState:
             if (
                 new_solution is None and state.is_base_constraint
             ) or new_solution is not None:
-                self.visited_out_of_bounds = True
+                self.visited_lower_bound_violation = True
             if new_solution is not None:
                 return state.update_solution(new_constraint, new_solution)
         return state
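The coverage loop inside fuzz_single_scope above follows a generate, run, update cycle. The outline below is a schematic sketch under stated assumptions: skeleton, generate_case, and run_case are hypothetical stand-ins for the helpers used in the patch, and the control flow keeps only the shape of the loop, not its exact behavior.

    def fuzz_until_covered(skeleton, generate_case, run_case, max_iters=20, max_skipped=100):
        """Drive generated test cases until every coverage target has been visited."""
        iters = 0
        skipped = 0
        while not skeleton.is_finished() and iters < max_iters:
            case = generate_case()  # solver-guided inputs; None when the constraint is unsolvable
            if case is None:
                skipped += 1
                if skipped > max_skipped:
                    raise RuntimeError("constraint solving kept failing; nothing was tested")
                continue
            result = run_case(case)
            if result == "failed":
                return False  # the caller broadens the scope and retries, as in the patch
            skeleton.update(result)  # fold this run's coverage back into the skeleton
            iters += 1
        return True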
From f33fa2c9dee5e5886eeda6a74a36fb019ac676fd Mon Sep 17 00:00:00 2001
From: Kenneth Moon
Date: Tue, 27 May 2025 16:40:04 -0400
Subject: [PATCH 23/24] chexo as module

---
 src/exo/rewrite/chexo/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 src/exo/rewrite/chexo/__init__.py

diff --git a/src/exo/rewrite/chexo/__init__.py b/src/exo/rewrite/chexo/__init__.py
new file mode 100644
index 000000000..e69de29bb

From 71aea99e10950b4fdb8faa2869246e5efea6a770 Mon Sep 17 00:00:00 2001
From: Kenneth Moon
Date: Thu, 24 Jul 2025 21:00:42 -0500
Subject: [PATCH 24/24] fix imports

---
 src/exo/rewrite/chexo/chexo.py    | 2 +-
 src/exo/rewrite/chexo/coverage.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/exo/rewrite/chexo/chexo.py b/src/exo/rewrite/chexo/chexo.py
index 843cb68ad..c983ba63d 100644
--- a/src/exo/rewrite/chexo/chexo.py
+++ b/src/exo/rewrite/chexo/chexo.py
@@ -12,7 +12,7 @@ from ...core.prelude import Sym, SrcInfo
 from ...core.memory import DRAM, Memory
 import numpy as np
-from ..analysis import SchedulingError
+from ..new_eff import SchedulingError
 from .constraint_solver import (
     TRUE_CONSTRAINT,
     ConstraintMaker,
diff --git a/src/exo/rewrite/chexo/coverage.py b/src/exo/rewrite/chexo/coverage.py
index e6fd80e37..aba759564 100644
--- a/src/exo/rewrite/chexo/coverage.py
+++ b/src/exo/rewrite/chexo/coverage.py
@@ -12,7 +12,7 @@
 )
 from ...core.prelude import Sym
 from ...core.internal_cursors import NodePath
-from ..analysis import SchedulingError
+from ..new_eff import SchedulingError
 
 
 @dataclass
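With the split flags, each StagingBoundCheck in the coverage.py changes above now reports two coverage targets instead of one. The minimal illustration below treats the progress value abstractly as a (visited, total) pair rather than the CoverageProgress class defined in the patch.

    # Abstract view of the bookkeeping; not the class from the patch.
    def bound_check_progress(visited_upper: bool, visited_lower: bool) -> tuple[int, int]:
        """Each staging bound check contributes two coverage targets."""
        return (int(visited_upper) + int(visited_lower), 2)

    assert bound_check_progress(False, False) == (0, 2)
    assert bound_check_progress(True, False) == (1, 2)
    assert bound_check_progress(True, True) == (2, 2)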