diff --git a/requirements.txt b/requirements.txt index 03ed1c373..5a054cd8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,6 @@ asdl==0.1.5 build==1.2.2.post1 z3-solver==4.14.0.0 yapf==0.43.0 +scipy==1.13.1 +hsnf==0.3.16 +pythonmonkey==1.1.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index dcdb9cf74..fc247a342 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,9 @@ install_requires = build>=1.2.1 z3-solver>=4.13.0.0 yapf>=0.40.2 + scipy>=1.13.1 + hsnf>=0.3.16 + pythonmonkey>=1.1.0 [options.packages.find] where = src diff --git a/src/exo/API.py b/src/exo/API.py index d00ff5de3..557d4f042 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -22,17 +22,18 @@ # Moved to new file from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc from .frontend.pyparser import get_ast_from_python, Parser, get_parent_scope -from .frontend.typecheck import TypeChecker +from .frontend.typecheck import TypeChecker, CheckMode from . import API_cursors as C from .core import internal_cursors as IC +from .backend.LoopIR_compiler import DEFAULT_CHECK_MODE # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Top-level decorator -def proc(f, _instr=None) -> "Procedure": +def proc(f, _instr=None, _check_mode: Optional[CheckMode] = None) -> "Procedure": if not isinstance(f, types.FunctionType): raise TypeError("@proc decorator must be applied to a function") @@ -42,14 +43,14 @@ def proc(f, _instr=None) -> "Procedure": parser = Parser( body, src_info, - parent_scope=get_parent_scope(depth=3 if _instr else 2), + parent_scope=get_parent_scope(depth=3 if _instr or _check_mode else 2), instr=_instr, as_func=True, ) - return Procedure(parser.result()) + return Procedure(parser.result(), _check_mode=_check_mode) -def instr(c_instr, c_global=""): +def instr(c_instr, c_global="", check_mode=None): if not isinstance(c_instr, str): raise TypeError("@instr decorator must be @instr()") @@ -57,7 +58,7 @@ def inner(f): if not isinstance(f, types.FunctionType): raise TypeError("@instr decorator must be applied to a function") - return proc(f, _instr=(c_instr, c_global)) + return proc(f, _instr=(c_instr, c_global), _check_mode=check_mode) return inner @@ -84,6 +85,14 @@ def parse_config(cls): return parse_config(_cls) +def chexo(f): + return proc(f, _check_mode="dynamic") + + +def chexo_debug(f): + return proc(f, _check_mode="both") + + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # iPython Display Object @@ -150,7 +159,11 @@ def compile_procs(proc_list, basedir: Path, c_file: str, h_file: str): def compile_procs_to_strings(proc_list, h_file_name: str): assert isinstance(proc_list, list) assert all(isinstance(p, Procedure) for p in proc_list) - return run_compile([p._loopir_proc for p in proc_list], h_file_name) + return run_compile( + [p._loopir_proc for p in proc_list], + h_file_name, + "dynamic" if any(p._check_mode == "dynamic" for p in proc_list) else "static", + ) class Procedure(ProcedureBase): @@ -160,14 +173,25 @@ def __init__( _provenance_eq_Procedure: "Procedure" = None, _forward=None, _mod_config=None, + _check_mode: Optional[CheckMode] = None, ): super().__init__() _mod_config = _mod_config or frozenset() + self._check_mode = ( + ( + _provenance_eq_Procedure._check_mode + if _provenance_eq_Procedure is not None + else 
DEFAULT_CHECK_MODE + ) + if _check_mode is None + else _check_mode + ) if isinstance(proc, LoopIR.UAST.proc): - proc = TypeChecker(proc).get_loopir() - CheckBounds(proc) + proc = TypeChecker(proc, self._check_mode).get_loopir() + if self._check_mode != "dynamic": + CheckBounds(proc) Check_Aliasing(proc) assert isinstance(proc, LoopIR.LoopIR.proc) @@ -294,7 +318,9 @@ def find_all(self, pattern): # ---------------------------------------------- # def c_code_str(self): - decls, defns = compile_to_strings("c_code_str", [self._loopir_proc]) + decls, defns = compile_to_strings( + "c_code_str", [self._loopir_proc], check_mode=self._check_mode + ) return decls + "\n" + defns def compile_c(self, directory: Path, filename: str): diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index c35bc280a..ac66295b9 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -889,7 +889,7 @@ def reorder_stmts(proc, block_cursor): s1 = block_cursor[0]._impl s2 = block_cursor[1]._impl - ir, fwd = scheduling.DoReorderStmt(s1, s2) + ir, fwd = scheduling.DoReorderStmt(s1, s2, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -897,7 +897,7 @@ def reorder_stmts(proc, block_cursor): def parallelize_loop(proc, loop_cursor): loop = loop_cursor._impl - ir, fwd = scheduling.DoParallelizeLoop(loop) + ir, fwd = scheduling.DoParallelizeLoop(loop, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -977,7 +977,7 @@ def rewrite_expr(proc, expr_cursor, new_expr): -> `s[ expr_cursor -> new_expr]` """ - ir, fwd = scheduling.DoRewriteExpr(expr_cursor._impl, new_expr) + ir, fwd = scheduling.DoRewriteExpr(expr_cursor._impl, new_expr, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1003,13 +1003,16 @@ def bind_expr(proc, expr_cursors, new_name): `a = b + 4.0` """ exprs = [ec._impl for ec in expr_cursors] - if any(not e._node.type.is_numeric() for e in exprs): + if ( + any(not e._node.type.is_numeric() for e in exprs) + and not proc._check_mode == "dynamic" + ): raise TypeError( "only numeric (not index or size) expressions " "can be bound by bind_expr()" ) - ir, fwd = scheduling.DoBindExpr(new_name, exprs) + ir, fwd = scheduling.DoBindExpr(new_name, exprs, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -1202,10 +1205,14 @@ def bind_config(proc, var_cursor, config, field): e = var_cursor._impl._node cfg_f_type = config.lookup_type(field) - if not isinstance(e, LoopIR.Read): + if not isinstance(e, LoopIR.Read) and proc._check_mode != "dynamic": raise TypeError("expected a cursor to a single variable Read") - if not (e.type.is_real_scalar() and len(e.idx) == 0) and not e.type.is_bool(): + if ( + not (e.type.is_real_scalar() and len(e.idx) == 0) + and not e.type.is_bool() + and proc._check_mode != "dynamic" + ): raise TypeError( f"cannot bind non-real-scalar non-boolean value {e} to configuration states, since index and size expressions may depend on loop iteration" ) @@ -1216,7 +1223,9 @@ def bind_config(proc, var_cursor, config, field): f"to match type of Config variable ({cfg_f_type})" ) - ir, fwd, cfg = scheduling.DoBindConfig(config, field, var_cursor._impl) + ir, fwd, cfg = scheduling.DoBindConfig( + config, field, var_cursor._impl, proc._check_mode + ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd, _mod_config=cfg) @@ -1232,10 +1241,39 @@ def delete_config(proc, stmt_cursor): rewrite: `s1 ; config.field = _ ; s3 -> s1 ; s3` """ - ir, 
fwd, cfg = scheduling.DoDeleteConfig(proc._root(), stmt_cursor._impl)
+    ir, fwd, cfg = scheduling.DoDeleteConfig(
+        proc._root(), stmt_cursor._impl, proc._check_mode
+    )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd, _mod_config=cfg)
 
 
+@sched_op([StmtCursorA])
+def delete_stmt(proc, stmt_cursor):
+    """
+    delete a statement
+
+    args:
+        stmt_cursor - cursor or pattern pointing at the statement to
+                      be deleted
+
+    rewrite:
+        `s1 ; s2 ; s3  ->  s1 ; s3`
+    """
+    ir, fwd = scheduling.DoDeleteStmt(stmt_cursor._impl, proc._check_mode)
+    return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
+
+
+@sched_op([GapCursorA, NewExprA("gap_cursor"), NewExprA("gap_cursor"), BoolA])
+def insert_mutate(proc, gap_cursor, buf_read, rhs, is_reduce):
+    """
+    insert a new statement at the gap that assigns (or, if `is_reduce`
+    is set, reduces) `rhs` into the scalar buffer read by `buf_read`
+
+    rewrite:
+        `s1 ; s2  ->  s1 ; x = rhs ; s2`
+    """
+    if not (isinstance(buf_read, LoopIR.Read) and len(buf_read.idx) == 0):
+        raise SchedulingError(
+            "expected buf_read to be a read of a scalar (zero-dimensional) buffer"
+        )
+    new_stmt = (LoopIR.Reduce if is_reduce else LoopIR.Assign)(
+        buf_read.name, buf_read.type, [], rhs, buf_read.srcinfo
+    )
+    ir, fwd = scheduling.DoInsertStmt(gap_cursor._impl, new_stmt, proc._check_mode)
+    return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
+
+
 @sched_op([GapCursorA, ConfigA, ConfigFieldA, NewExprA("gap_cursor")])
 def write_config(proc, gap_cursor, config, field, rhs):
     """
@@ -1269,7 +1307,9 @@ def write_config(proc, gap_cursor, config, field, rhs):
         )
 
     stmt = stmtc._impl
-    ir, fwd, cfg = scheduling.DoConfigWrite(stmt, config, field, rhs, before=before)
+    ir, fwd, cfg = scheduling.DoConfigWrite(
+        stmt, config, field, rhs, proc._check_mode, before=before
+    )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd, _mod_config=cfg)
 
 
@@ -1318,11 +1358,13 @@ def resize_dim(proc, buf_cursor, dim_idx, size, offset, fold: bool = False):
         assert isinstance(size, LoopIR.Const) and size.val > 0
         size = size.val
         buf_s = buf_cursor._impl
-        ir, fwd = scheduling.DoFoldBuffer(buf_s, dim_idx, size)
+        ir, fwd = scheduling.DoFoldBuffer(buf_s, dim_idx, size, proc._check_mode)
         return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
     else:
         # Normal resize operation
-        ir, fwd = scheduling.DoResizeDim(stmt_c, dim_idx, size, offset)
+        ir, fwd = scheduling.DoResizeDim(
+            stmt_c, dim_idx, size, offset, proc._check_mode
+        )
         return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1352,7 +1394,7 @@ def expand_dim(proc, buf_cursor, alloc_dim, indexing_expr):
     provided indexing expression is checked to make sure it is in-bounds
     """
     stmt_c = buf_cursor._impl
-    ir, fwd = scheduling.DoExpandDim(stmt_c, alloc_dim, indexing_expr)
+    ir, fwd = scheduling.DoExpandDim(stmt_c, alloc_dim, indexing_expr, proc._check_mode)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1408,7 +1450,7 @@ def divide_dim(proc, alloc_cursor, dim_idx, quotient):
     if not (0 <= dim_idx < len(stmt._node.type.shape())):
         raise ValueError(f"Cannot divide out-of-bounds dimension index {dim_idx}")
 
-    ir, fwd = scheduling.DoDivideDim(stmt, dim_idx, quotient)
+    ir, fwd = scheduling.DoDivideDim(stmt, dim_idx, quotient, proc._check_mode)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1582,7 +1624,7 @@ def reuse_buffer(proc, buf_cursor, replace_cursor):
     """
     buf_s = buf_cursor._impl
     rep_s = replace_cursor._impl
-    ir, fwd = scheduling.DoReuseBuffer(buf_s, rep_s)
+    ir, fwd = scheduling.DoReuseBuffer(buf_s, rep_s, proc._check_mode)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1653,7 +1695,12 @@ def stage_mem(proc, block_cursor, win_expr, new_buf_name, accum=False):
     """
     buf_name, w_exprs = win_expr
ir, fwd = scheduling.DoStageMem(
-        block_cursor._impl, buf_name, w_exprs, new_buf_name, use_accum_zero=accum
+        block_cursor._impl,
+        buf_name,
+        w_exprs,
+        new_buf_name,
+        proc._check_mode,
+        use_accum_zero=accum,
     )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1679,7 +1726,12 @@ def divide_with_recompute(proc, loop_cursor, outer_hi, outer_stride, new_iters):
     `    s[ i -> outer_stride * io + ii ]`
     """
     ir, fwd = scheduling.DoDivideWithRecompute(
-        loop_cursor._impl, outer_hi, outer_stride, new_iters[0], new_iters[1]
+        loop_cursor._impl,
+        outer_hi,
+        outer_stride,
+        new_iters[0],
+        new_iters[1],
+        proc._check_mode,
     )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1736,12 +1788,74 @@ def divide_loop(proc, loop_cursor, div_const, new_iters, tail="guard", perfect=F
         quot=div_const,
         outer_iter=new_iters[0],
         inner_iter=new_iters[1],
+        check_mode=proc._check_mode,
         tail=tail,
         perfect=perfect,
     )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
+@sched_op(
+    [
+        ForCursorA,
+        ConfigA,
+        ConfigFieldA("quot_config"),
+        ListA(NameA, length=2),
+    ]
+)
+def divide_loop_min(proc, loop_cursor, quot_config, quot_field, new_iters):
+    """
+    Divide a loop into an outer and inner loop, where the inner loop
+    iterates from 0 up to a quotient read at runtime from the given
+    Config field, clamped with `intmin` so that the final outer
+    iteration does not overrun the original loop bound. Unlike
+    `divide_loop`, no separate tail loop is emitted, and this
+    directive is only available under dynamic (chexo) checking.
+
+    args:
+        loop_cursor - cursor pointing to the loop to divide ;
+                      can also be specified using the special shorthands
+                      pattern: <loop-iterator-name>
+                          or: <loop-iterator-name> #<int>
+        quot_config - Config object holding the value to "divide by"
+        quot_field  - name of the field of `quot_config` that holds
+                      the quotient
+        new_iters   - list or tuple of two strings specifying the new
+                      outer and inner iteration variable names
+
+    rewrite:
+        divide_loop_min(..., new_iters=['hi','lo'])
+        # below, q stands for the value of quot_config.quot_field
+        `for i in seq(0,e):`
+        `    s`
+        ->
+        `for hi in seq(0, (e + q - 1) / q):`
+        `    for lo in seq(0, intmin(q, e - q * hi)):`
+        `        s[ i -> q * hi + lo ]`
+    """
+
+    stmt = loop_cursor._impl
+
+    ir, fwd = scheduling.DoDivideLoopMin(
+        stmt,
+        quot=LoopIR.ReadConfig(
+            quot_config,
+            quot_field,
+            quot_config.lookup_type(quot_field),
+            loop_cursor._impl._node.srcinfo,
+        ),
+        outer_iter=new_iters[0],
+        inner_iter=new_iters[1],
+        check_mode=proc._check_mode,
+    )
+    return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
+
+
 @sched_op([NestedForCursorA, NameA])
 def mult_loops(proc, nested_loops, new_iter_name):
     """
@@ -1785,7 +1899,9 @@ def join_loops(proc, loop1_cursor, loop2_cursor):
     `for i in seq(lo, hi):`
     `    s`
     """
-    ir, fwd = scheduling.DoJoinLoops(loop1_cursor._impl, loop2_cursor._impl)
+    ir, fwd = scheduling.DoJoinLoops(
+        loop1_cursor._impl, loop2_cursor._impl, proc._check_mode
+    )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1813,7 +1929,7 @@ def cut_loop(proc, loop_cursor, cut_point):
     `for i in seq(cut, n):`
     `    s`
     """
-    ir, fwd = scheduling.DoCutLoop(loop_cursor._impl, cut_point)
+    ir, fwd = scheduling.DoCutLoop(loop_cursor._impl, cut_point, proc._check_mode)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1836,7 +1952,7 @@ def shift_loop(proc, loop_cursor, new_lo):
     `for i in seq(new_lo, new_lo + n - m):`
     `    s(i + (m - new_lo))`
     """
-    ir, fwd = scheduling.DoShiftLoop(loop_cursor._impl, new_lo)
+    ir, fwd = scheduling.DoShiftLoop(loop_cursor._impl, new_lo, proc._check_mode)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -1871,7 +1987,7 @@ def reorder_loops(proc, nested_loops):
     if len(stmt_c.body()) != 1 or not isinstance(stmt_c.body()[0]._node, LoopIR.For):
         raise ValueError(f"expected loop directly inside of {stmt_c._node.iter} loop")
 
-    ir, fwd = scheduling.DoLiftScope(stmt_c.body()[0])
+    ir, fwd = scheduling.DoLiftScope(stmt_c.body()[0], proc._check_mode)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -2064,7 +2180,7 @@ def fission(proc, gap_cursor, n_lifts=1, unsafe_disable_checks=False):
     )
 
     ir, fwd = scheduling.DoFissionAfterSimple(
-        stmt._impl, n_lifts, unsafe_disable_checks
+        stmt._impl, n_lifts, unsafe_disable_checks, proc._check_mode
     )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -2146,9 +2262,9 @@ def fuse(proc, stmt1, stmt2, unsafe_disable_check=False):
     s1 = stmt1._impl
     s2 = stmt2._impl
     if isinstance(stmt1, PC.IfCursor):
-        ir, fwd = scheduling.DoFuseIf(s1, s2)
+        ir, fwd = scheduling.DoFuseIf(s1, s2, proc._check_mode)
     else:
-        ir, fwd = scheduling.DoFuseLoop(s1, s2, unsafe_disable_check)
+        ir, fwd = scheduling.DoFuseLoop(s1, s2, proc._check_mode, unsafe_disable_check)
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -2168,7 +2284,9 @@ def remove_loop(proc, loop_cursor, unsafe_disable_check=False):
     ->
     `s`
     """
-    ir, fwd = scheduling.DoRemoveLoop(loop_cursor._impl, unsafe_disable_check)
+    ir, fwd = scheduling.DoRemoveLoop(
+        loop_cursor._impl, unsafe_disable_check, proc._check_mode
+    )
     return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
 
 
@@ -2203,7 +2321,7 @@ def add_loop(
     stmt_c = block_cursor[0]._impl
 
     ir, fwd = scheduling.DoAddLoop(
-        stmt_c, iter_name, hi_expr, guard, unsafe_disable_check
+        stmt_c, iter_name, hi_expr, guard, unsafe_disable_check, proc._check_mode
    )
    return Procedure(ir, _provenance_eq_Procedure=proc,
_forward=fwd) @@ -2257,7 +2375,7 @@ def lift_scope(proc, scope_cursor): """ stmt_c = scope_cursor._impl - ir, fwd = scheduling.DoLiftScope(stmt_c) + ir, fwd = scheduling.DoLiftScope(stmt_c, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) @@ -2279,7 +2397,7 @@ def eliminate_dead_code(proc, stmt_cursor): `s1` """ - ir, fwd = scheduling.DoEliminateDeadCode(stmt_cursor._impl) + ir, fwd = scheduling.DoEliminateDeadCode(stmt_cursor._impl, proc._check_mode) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 95fe0c050..d44c7aaa3 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -5,6 +5,8 @@ proc, instr, config, + chexo, + chexo_debug, ExoType, ) from .rewrite.LoopIR_scheduling import SchedulingError diff --git a/src/exo/backend/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py index 090a215de..7bd41eb5a 100644 --- a/src/exo/backend/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -15,6 +15,10 @@ from ..core.prelude import * from .win_analysis import WindowAnalysis from ..rewrite.range_analysis import IndexRangeEnvironment +from ..rewrite.chexo.chexo import fuzz, fuzz_single_scope +from ..core.internal_cursors import Cursor + +DEFAULT_CHECK_MODE = "static" def sanitize_str(s): @@ -52,7 +56,11 @@ def sanitize_str(s): def lift_to_cir(e, range_env): assert e.type.is_indexable(), "why are you here?" - is_non_neg = lambda e: range_env.check_expr_bound(0, IndexRangeEnvironment.leq, e) + is_non_neg = lambda e: ( + False + if range_env is None + else range_env.check_expr_bound(0, IndexRangeEnvironment.leq, e) + ) if isinstance(e, LoopIR.Read): return CIR.Read(e.name, is_non_neg(e)) @@ -320,10 +328,10 @@ def window_struct(base_type, n_dims, is_const) -> WindowStruct: # top level compiler function called by tests! 
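# A minimal sketch (not part of the patch) of how the check mode is
# resolved by the entry points below: callers may pass check_mode=None,
# in which case compile_to_strings falls back to DEFAULT_CHECK_MODE
# ("static"), while compile_procs_to_strings in API.py selects "dynamic"
# as soon as any one proc was built with _check_mode="dynamic".
# `resolve_check_mode` is a hypothetical helper named only for
# illustration; the patch inlines this logic at the call sites.
#
#     def resolve_check_mode(proc_list, check_mode=None):
#         if check_mode is not None:
#             return check_mode
#         if any(p._check_mode == "dynamic" for p in proc_list):
#             return "dynamic"
#         return DEFAULT_CHECK_MODE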
-def run_compile(proc_list, h_file_name: str): +def run_compile(proc_list, h_file_name: str, check_mode=None): file_stem = str(Path(h_file_name).stem) lib_name = sanitize_str(file_stem) - fwd_decls, body = compile_to_strings(lib_name, proc_list) + fwd_decls, body = compile_to_strings(lib_name, proc_list, check_mode) source = f'#include "{h_file_name}"\n\n{body}' @@ -360,7 +368,8 @@ def run_compile(proc_list, h_file_name: str): } -def compile_to_strings(lib_name, proc_list): +def compile_to_strings(lib_name, proc_list, check_mode=None): + check_mode = DEFAULT_CHECK_MODE if check_mode is None else check_mode # Get transitive closure of call-graph orig_procs = [id(p) for p in proc_list] @@ -407,12 +416,19 @@ def from_lines(x): else: is_public_decl = id(p) in orig_procs - p = ParallelAnalysis().run(p) + if check_mode != "dynamic": + # fixme: need to check parallel analysis on static procs even if one of the procs is dynamic + p = ParallelAnalysis().run(p) + else: + proc_cursor = Cursor.create(p).body() + fuzz_single_scope(proc_cursor) p = PrecisionAnalysis().run(p) p = WindowAnalysis().apply_proc(p) p = MemoryAnalysis().run(p) - comp = Compiler(p, ctxt_name, is_public_decl=is_public_decl) + comp = Compiler( + p, ctxt_name, is_public_decl=is_public_decl, check_mode=check_mode + ) d, b = comp.comp_top() struct_defns |= comp.struct_defns() needed_helpers |= comp.needed_helpers() @@ -522,13 +538,15 @@ def _compile_context_struct(configs, lib_name): class Compiler: - def __init__(self, proc, ctxt_name, *, is_public_decl): + def __init__(self, proc, ctxt_name, *, is_public_decl, check_mode): assert isinstance(proc, LoopIR.proc) self.proc = proc self.ctxt_name = ctxt_name self.env = ChainMap() - self.range_env = IndexRangeEnvironment(proc, fast=False) + self.range_env = ( + None if check_mode == "dynamic" else IndexRangeEnvironment(proc, fast=False) + ) self.names = ChainMap() self.envtyp = dict() self.mems = dict() @@ -693,12 +711,14 @@ def new_varname(self, symbol, typ, mem=None): def push(self, only=None): if only is None: self.env = self.env.new_child() - self.range_env.enter_scope() + if self.range_env is not None: + self.range_env.enter_scope() self.names = self.names.new_child() self._tab = self._tab + " " elif only == "env": self.env = self.env.new_child() - self.range_env.enter_scope() + if self.range_env is not None: + self.range_env.enter_scope() self.names = self.names.new_child() elif only == "tab": self._tab = self._tab + " " @@ -707,7 +727,8 @@ def push(self, only=None): def pop(self): self.env = self.env.parents - self.range_env.exit_scope() + if self.range_env is not None: + self.range_env.exit_scope() self.names = self.names.parents self._tab = self._tab[:-2] @@ -894,11 +915,12 @@ def comp_s(self, s): hi = self.comp_e(s.hi) self.push(only="env") itr = self.new_varname(s.iter, typ=T.index) # allocate a new string - self.range_env.add_loop_iter( - s.iter, - s.lo, - s.hi, - ) + if self.range_env is not None: + self.range_env.add_loop_iter( + s.iter, + s.lo, + s.hi, + ) if isinstance(s.loop_mode, LoopIR.Par): self.add_line(f"#pragma omp parallel for") self.add_line(f"for (int_fast32_t {itr} = {lo}; {itr} < {hi}; {itr}++) {{") @@ -1035,7 +1057,9 @@ def comp_e(self, e, prec=0): rhs = self.comp_e(e.rhs, local_prec + 1) if int_div: - if self.range_env.check_expr_bound(0, IndexRangeEnvironment.leq, e): + if self.range_env is None or self.range_env.check_expr_bound( + 0, IndexRangeEnvironment.leq, e + ): # TODO: too many parens? 
return f"(({lhs}) / ({rhs}))" return self._call_static_helper("exo_floor_div", lhs, rhs) diff --git a/src/exo/core/extern.py b/src/exo/core/extern.py index b1ae39d6d..d4322c8e3 100644 --- a/src/exo/core/extern.py +++ b/src/exo/core/extern.py @@ -32,5 +32,11 @@ def typecheck(self, args): def interpret(self, args): raise NotImplementedError() + def transpile(self, args): + raise NotImplementedError() + def compile(self, args, prim_type): raise NotImplementedError() + + def express_in_constraints(self, args, out_sym): + raise NotImplementedError() diff --git a/src/exo/core/internal_cursors.py b/src/exo/core/internal_cursors.py index 1d4c98a8f..9f2957790 100644 --- a/src/exo/core/internal_cursors.py +++ b/src/exo/core/internal_cursors.py @@ -581,6 +581,11 @@ def forward(cursor: Node): return forward +@dataclass(frozen=True) +class NodePath: + path: tuple[tuple[str, Optional[int]], ...] + + @dataclass class Node(Cursor): _path: list[tuple[str, Optional[int]]] @@ -605,6 +610,13 @@ def _node(self): return n + # ------------------------------------------------------------------------ # + # Hashable path accessor + # ------------------------------------------------------------------------ # + + def get_path(self) -> NodePath: + return NodePath(tuple(self._path)) + # ------------------------------------------------------------------------ # # Navigation (implementation) # ------------------------------------------------------------------------ # diff --git a/src/exo/frontend/typecheck.py b/src/exo/frontend/typecheck.py index 9c89cb7e9..9bd237694 100644 --- a/src/exo/frontend/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -1,3 +1,4 @@ +from typing import Literal, Union from ..core.LoopIR import ( T, UAST, @@ -10,7 +11,6 @@ from ..core.extern import Extern_Typecheck_Error from ..core.memory import * - # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -35,6 +35,10 @@ # The typechecker +def is_mutable_index(t: LoopIR.type): + return isinstance(t, (T.INT8, T.UINT8, T.UINT16, T.INT32)) + + def check_call_types(err_handler, args, call_args): for call_a, sig_a in zip(args, call_args): if call_a.type == T.err: @@ -46,6 +50,7 @@ def check_call_types(err_handler, args, call_args): "expected size or index type " "expression, " f"but got type {call_a.type}", + allowed_in_chexo=is_mutable_index(call_a.type), ) elif sig_a.type is T.bool: @@ -92,11 +97,16 @@ def check_call_types(err_handler, args, call_args): assert False, "bad argument type case" +Checker = Literal["static", "dynamic"] +CheckMode = Union[Checker, Literal["both"]] + + class TypeChecker: - def __init__(self, proc): + def __init__(self, proc, check_mode: CheckMode): self.uast_proc = proc self.env = dict() self.errors = [] + self.check_mode = check_mode args = [] for a in proc.args: @@ -126,6 +136,7 @@ def __init__(self, proc): self.err( proc, f"expected writes to configuration {name[0].name()}.{name[1]} does not depend on loop iterations", + allowed_in_chexo=True, ) instr = proc.instr @@ -150,8 +161,9 @@ def __init__(self, proc): def get_loopir(self): return self.loopir_proc - def err(self, node, msg): - self.errors.append(f"{node.srcinfo}: {msg}") + def err(self, node, msg, *, allowed_in_chexo=False): + if not allowed_in_chexo or self.check_mode != "dynamic": + self.errors.append(f"{node.srcinfo}: {msg}") def check_stmts(self, body): assert len(body) > 0 or self.uast_proc.instr @@ -165,7 +177,11 @@ def check_access(self, node, nm, 
idx, lvalue=False): idx = [self.check_e(i, is_index=True) for i in idx] for i in idx: if i.type != T.err and not i.type.is_indexable(): - self.err(i, f"cannot index with expression of type '{i.type}'") + self.err( + i, + f"cannot index with expression of type '{i.type}'", + allowed_in_chexo=is_mutable_index(i.type), + ) # check compatibility with buffer type typ = self.env[nm] @@ -236,7 +252,11 @@ def check_single_stmt(self, stmt): if isinstance(stmt, (UAST.Assign, UAST.Reduce)): rhs = self.check_e(stmt.rhs) if rhs.type != T.err and not rhs.type.is_real_scalar(): - self.err(rhs, f"cannot assign/reduce a '{rhs.type}' type value") + self.err( + rhs, + f"cannot assign/reduce a '{rhs.type}' type value", + allowed_in_chexo=rhs.type.is_indexable(), + ) idx, typ = self.check_access(stmt, stmt.name, stmt.idx, lvalue=True) assert typ.is_real_scalar() or typ is T.err @@ -266,6 +286,8 @@ def check_single_stmt(self, stmt): rhs, f"expected a real scalar value, but " f"got an expression of type {rhs.type}", + allowed_in_chexo=is_mutable_index(ftyp) + and rhs.type.is_indexable(), ) elif ftyp.is_indexable(): if not rhs.type.is_indexable(): @@ -273,6 +295,7 @@ def check_single_stmt(self, stmt): rhs, f"expected an index or size type " f"expression, but got type {rhs.type}", + allowed_in_chexo=is_mutable_index(rhs.type), ) elif ftyp == T.bool: if rhs.type != T.bool: @@ -318,10 +341,18 @@ def check_single_stmt(self, stmt): lo = self.check_e(stmt.cond.lo, is_index=True) if lo.type != T.err and not lo.type.is_indexable(): - self.err(lo, "expected loop bound to be indexable.") + self.err( + lo, + "expected loop bound to be indexable.", + allowed_in_chexo=is_mutable_index(lo.type), + ) hi = self.check_e(stmt.cond.hi, is_index=True) if hi.type != T.err and not hi.type.is_indexable(): - self.err(hi, "expected loop bound to be indexable.") + self.err( + hi, + "expected loop bound to be indexable.", + allowed_in_chexo=is_mutable_index(hi.type), + ) body = self.check_stmts(stmt.body) if isinstance(stmt.cond, UAST.SeqRange): @@ -357,7 +388,11 @@ def check_w_access(self, e, orig_hi): if isinstance(e, UAST.Point): pt = self.check_e(e.pt, is_index=True) if pt.type != T.err and not pt.type.is_indexable(): - self.err(pt, f"cannot index with expression of type '{pt.type}'") + self.err( + pt, + f"cannot index with expression of type '{pt.type}'", + allowed_in_chexo=is_mutable_index(pt.type), + ) return LoopIR.Point(pt, e.srcinfo) elif isinstance(e, UAST.Interval): @@ -366,14 +401,22 @@ def check_w_access(self, e, orig_hi): else: lo = self.check_e(e.lo, is_index=True) if lo.type != T.err and not lo.type.is_indexable(): - self.err(lo, f"cannot index with expression of type '{lo.type}'") + self.err( + lo, + f"cannot index with expression of type '{lo.type}'", + allowed_in_chexo=is_mutable_index(lo.type), + ) if e.hi is None: hi = orig_hi else: hi = self.check_e(e.hi, is_index=True) if hi.type != T.err and not hi.type.is_indexable(): - self.err(hi, f"cannot index with expression of type '{hi.type}'") + self.err( + hi, + f"cannot index with expression of type '{hi.type}'", + allowed_in_chexo=is_mutable_index(hi.type), + ) return LoopIR.Interval(lo, hi, e.srcinfo) @@ -472,20 +515,33 @@ def check_e(self, e, is_index=False): operand, f"expected 'index' or 'size' argument to " f"comparison op: {e.op}", + allowed_in_chexo=operand.type.is_real_scalar(), ) typ = T.bool elif e.op in ("+", "-", "*", "/", "%"): if lhs.type.is_real_scalar(): if not rhs.type.is_real_scalar(): - self.err(rhs, "expected scalar type") - typ = T.err + self.err( + 
rhs, + "expected scalar type", + allowed_in_chexo=rhs.type.is_indexable(), + ) elif e.op == "%": - self.err(e, "cannot compute modulus of 'R' values") - typ = T.err - else: - typ = lhs.type + self.err( + e, + "cannot compute modulus of 'R' values", + allowed_in_chexo=is_mutable_index(lhs.type) + and is_mutable_index(rhs.type), + ) + typ = lhs.type elif rhs.type.is_real_scalar(): - self.err(lhs, "expected scalar type") + self.err( + lhs, + "expected scalar type", + allowed_in_chexo=is_mutable_index(lhs.type) + and is_mutable_index(rhs.type), + ) + typ = lhs.type elif lhs.type == T.bool or rhs.type == T.bool: node = lhs if lhs.type == T.bool else rhs self.err(node, "cannot perform arithmetic on 'bool' values") @@ -505,15 +561,15 @@ def check_e(self, e, is_index=False): self.err( rhs, "cannot divide or modulo by a " "non-constant value", + allowed_in_chexo=True, ) - typ = T.err elif rhs.val <= 0: self.err( rhs, "cannot divide or modulo by zero " "or a negative value", + allowed_in_chexo=True, ) - typ = T.err typ = lhs.type elif e.op == "*": @@ -527,8 +583,9 @@ def check_e(self, e, is_index=False): "cannot multiply two non-constant " "indexing/sizing expressions, since " "the result would be non-affine", + allowed_in_chexo=True, ) - typ = T.err + typ = lhs.type else: # + or - if lhs.type == T.index or rhs.type == T.index: typ = T.index @@ -617,6 +674,7 @@ def check_t(self, typ): h, "expected array size expression " "to have type 'size' or type 'index'", + allowed_in_chexo=is_mutable_index(h.type), ) return T.Tensor(hi, typ.is_window, sub_typ) else: diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py index 8752ed695..472224db1 100644 --- a/src/exo/libs/externs.py +++ b/src/exo/libs/externs.py @@ -1,4 +1,8 @@ -from exo.core.extern import Extern, _EErr +from ..core.extern import Extern, _EErr +import numpy as np + +from ..rewrite.chexo.constraint_solver import Constraint, DisjointConstraint, Expression +from ..core.prelude import Sym class _Sin(Extern): @@ -20,8 +24,11 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.sin(args[0]) + def interpret(self, args): + return np.sin(args[0]) + + def transpile(self, args): + return f"Math.sin({args[0]})" def compile(self, args, prim_type): return f"sin(({prim_type}){args[0]})" @@ -55,15 +62,37 @@ def globl(self, prim_type): ) return s - # def interpret(self, args): - # if args[0] > 0: - # return args[0] - # else: - # return 0 + def interpret(self, args): + if args[0] > 0: + return args[0] + else: + return 0 + + def transpile(self, args): + return f"(({args[0]}>0)?{args[0]}:0)" def compile(self, args, prim_type): return f"_relu_{prim_type}(({prim_type}){args[0]})" + def express_in_constraints( + self, args: tuple[Expression, ...], out_sym: Sym + ) -> DisjointConstraint: + result_expr = Expression.from_sym(out_sym) + return ( + Constraint(args[0], True) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[0].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + .union( + Constraint(args[0].negate(), True) + .lift_to_disjoint_constraint() + .intersect(Constraint(result_expr, False).lift_to_disjoint_constraint()) + ) + ) + relu = _Relu() @@ -95,19 +124,47 @@ def globl(self, prim_type): ) return s - # def interpret(self, args): - # x = args[0] - # v = args[1] - # y = args[2] - # z = args[3] - # if x < v: - # return y - # else: - # return z + def interpret(self, args): + x = args[0] + v = args[1] + y = args[2] + z = args[3] + if x < v: + return 
y + else: + return z + + def transpile(self, args): + return f"(({args[0]}<{args[1]})?{args[2]}:{args[3]})" def compile(self, args, prim_type): return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})" + def express_in_constraints( + self, args: tuple[Expression, ...], out_sym: Sym + ) -> DisjointConstraint: + result_expr = Expression.from_sym(out_sym) + return ( + Constraint( + args[1].add(args[0].add(Expression.from_constant(1)).negate()), True + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[2].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + .union( + Constraint(args[0].add(args[1].negate()), True) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[3].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + ) + ) + select = _Select() @@ -131,8 +188,11 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.expf(args[0]) + def interpret(self, args): + return np.exp(args[0]) + + def transpile(self, args): + return f"Math.exp({args[0]})" def compile(self, args, prim_type): return f"expf(({prim_type})({args[0]}))" @@ -161,8 +221,11 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.fmaxf(args[0], args[1]) + def interpret(self, args): + return np.nanmax([args[0], args[1]]) + + def transpile(self, args): + return f"(({args[0]}<{args[1]}&&{args[1]}=={args[1]})?{args[1]}:{args[0]})" def compile(self, args, prim_type): return f"fmaxf(({prim_type})({args[0]}), ({prim_type})({args[1]}))" @@ -195,8 +258,11 @@ def globl(self, prim_type): }} """ - # def interpret(self, args): - # return math.sigmoid(args[0]) + def interpret(self, args): + return 1 / (1 + np.exp(-args[0])) + + def transpile(self, args): + return f"1/(1+Math.exp(-{args[0]}))" def compile(self, args, prim_type): return f"sigmoid(({prim_type})({args[0]}))" @@ -224,11 +290,83 @@ def typecheck(self, args): def globl(self, prim_type): return "#include " - # def interpret(self, args): - # return math.sqrt(args[0]) + def interpret(self, args): + return np.sqrt(args[0]) + + def transpile(self, args): + return f"Math.sqrt({args[0]})" def compile(self, args, prim_type): return f"sqrt(({prim_type})({args[0]}))" sqrt = _Sqrt() + + +class _IntMin(Extern): + def __init__(self): + super().__init__("intmin") + + def typecheck(self, args): + if len(args) != 2: + raise _EErr(f"expected 2 arguments, got {len(args)}") + + for i in range(len(args)): + atyp = args[i].type + if not atyp.is_indexable() and not atyp.is_real_scalar(): + raise _EErr( + f"expected argument {i+1} to be a real scalar value or " + f"control flow value, but got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + s = ( + f"{prim_type} _intmin_{prim_type}({prim_type} x,{prim_type} v)" + " {\n" + " if (x < v) return x;\n" + " else return v;\n" + "}\n" + ) + return s + + def interpret(self, args): + x = args[0] + v = args[1] + if x < v: + return x + else: + return v + + def transpile(self, args): + return f"(({args[0]}<{args[1]})?{args[0]}:{args[1]})" + + def compile(self, args, prim_type): + return f"_intmin_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]})" + + def express_in_constraints( + self, args: tuple[Expression, ...], out_sym: Sym + ) -> DisjointConstraint: + result_expr = Expression.from_sym(out_sym) + return ( + Constraint( + 
args[1].add(args[0].add(Expression.from_constant(1)).negate()), True + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[0].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + .union( + Constraint(args[0].add(args[1].negate()), True) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + args[1].add(result_expr.negate()), False + ).lift_to_disjoint_constraint() + ) + ) + ) + + +intmin = _IntMin() diff --git a/src/exo/platforms/gemmini.py b/src/exo/platforms/gemmini.py index 5f598817b..c52a83efa 100644 --- a/src/exo/platforms/gemmini.py +++ b/src/exo/platforms/gemmini.py @@ -386,8 +386,8 @@ def ld_i8_block( assert n <= 16 assert m <= 4 assert stride(src, 1) == 1 - assert stride(dst, 0) == 16 - assert stride(dst, 1) == 1 + assert stride(dst, 1) == 16 + assert stride(dst, 2) == 1 for i in seq(0, n): for j in seq(0, m): @@ -481,8 +481,8 @@ def zero_block_id2( ): assert n <= 16 assert m <= 4 - assert stride(dst, 0) == 16 - assert stride(dst, 1) == 1 + assert stride(dst, 1) == 16 + assert stride(dst, 2) == 1 for i in seq(0, n): for j in seq(0, m): @@ -620,7 +620,7 @@ def ld_i8_vector( src: [i8][16] @ DRAM, dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 for i in seq(0, 16): dst[i] = src[i] @@ -636,7 +636,7 @@ def do_ld_i8_vector( src: [i8][16] @ DRAM, dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 for i in seq(0, 16): dst[i] = src[i] @@ -736,8 +736,8 @@ def ld_acc_i32_vector( dst: [i32][n, 16] @ GEMM_ACCUM, ): assert n <= 16 - assert stride(dst, 0) == 1 - assert stride(src, 0) == 1 + assert stride(dst, 1) == 1 + assert stride(src, 1) == 1 for i in seq(0, n): for j in seq(0, 16): @@ -754,8 +754,8 @@ def do_ld_acc_i32_vector( dst: [i32][n, 16] @ GEMM_ACCUM, ): assert n <= 16 - assert stride(dst, 0) == 1 - assert stride(src, 0) == 1 + assert stride(dst, 1) == 1 + assert stride(src, 1) == 1 for i in seq(0, n): for j in seq(0, 16): @@ -1088,7 +1088,7 @@ def del_and_zero(p): def zero_i8_vector( dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 pass for i in seq(0, 16): @@ -1102,7 +1102,7 @@ def zero_i8_vector( def do_zero_i8_vector( dst: [i8][16] @ GEMM_SCRATCH, ): - assert stride(dst, 0) == 16 + assert stride(dst, 0) == 1 for i in seq(0, 16): dst[i] = 0.0 diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index c67f0ec70..2ad97208e 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -1,6 +1,17 @@ import re from collections import ChainMap -from typing import List, Tuple, Optional +import traceback +from typing import ( + Any, + Callable, + Generator, + List, + Literal, + Tuple, + Optional, + TypeVar, + Union, +) from ..core.LoopIR import ( LoopIR, @@ -32,8 +43,11 @@ Check_ExprBound, Check_Aliasing, ) +from .chexo.chexo import fuzz, fuzz_single_scope +from .chexo.LoopIR_transpiler import StageMemArgs, StagedWindowExpr from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis +from ..core.internal_cursors import Block, Cursor, Node from ..core.prelude import * from ..core.proc_eqv import get_strictest_eqv_proc @@ -41,7 +55,8 @@ import exo.API as api from ..frontend.pattern_match import match_pattern from ..core.memory import DRAM -from ..frontend.typecheck import check_call_types +from ..frontend.typecheck import Checker, check_call_types, CheckMode +from ..libs.externs import intmin from functools import partial @@ 
-206,23 +221,77 @@ def _replace_pats(ir, fwd, c, pat, repl, only_replace_attrs=True, use_sym_id=Tru
         todos.append((rd, c_repl))
 
     cur_fwd = lambda x: x
-    for (rd, c_repl) in todos:
+    for rd, c_repl in todos:
         rd = cur_fwd(rd)
         ir, fwd_rd = _replace_helper(rd, c_repl, only_replace_attrs)
         cur_fwd = _compose(fwd_rd, cur_fwd)
 
     return ir, _compose(cur_fwd, fwd)
 
 
+def find_all_sym_reads(c: Union[Block, Node], sym: Sym) -> Generator[Node, None, None]:
+    if isinstance(c, Block):
+        for node in c:
+            yield from find_all_sym_reads(node, sym)
+    else:
+        if isinstance(c._node, (LoopIR.Assign, LoopIR.Reduce)):
+            yield from find_all_sym_reads(c._child_block("idx"), sym)
+            yield from find_all_sym_reads(c._child_node("rhs"), sym)
+        elif isinstance(c._node, LoopIR.WriteConfig):
+            yield from find_all_sym_reads(c._child_node("rhs"), sym)
+        elif isinstance(c._node, LoopIR.If):
+            yield from find_all_sym_reads(c._child_node("cond"), sym)
+            yield from find_all_sym_reads(c._child_block("body"), sym)
+            yield from find_all_sym_reads(c._child_block("orelse"), sym)
+        elif isinstance(c._node, LoopIR.For):
+            yield from find_all_sym_reads(c._child_node("lo"), sym)
+            yield from find_all_sym_reads(c._child_node("hi"), sym)
+            yield from find_all_sym_reads(c._child_block("body"), sym)
+        elif isinstance(c._node, LoopIR.Alloc):
+            yield from find_all_sym_reads(c._child_node("type"), sym)
+        elif isinstance(c._node, LoopIR.Free):
+            yield from find_all_sym_reads(c._child_node("type"), sym)
+        elif isinstance(c._node, LoopIR.Call):
+            yield from find_all_sym_reads(c._child_block("args"), sym)
+        elif isinstance(c._node, LoopIR.WindowStmt):
+            yield from find_all_sym_reads(c._child_node("rhs"), sym)
+        elif isinstance(c._node, LoopIR.expr):
+            yield from find_all_sym_reads(c._child_node("type"), sym)
+            if isinstance(c._node, LoopIR.Read):
+                yield from find_all_sym_reads(c._child_block("idx"), sym)
+                if c._node.name == sym:
+                    yield c
+            elif isinstance(c._node, LoopIR.USub):
+                yield from find_all_sym_reads(c._child_node("arg"), sym)
+            elif isinstance(c._node, LoopIR.BinOp):
+                yield from find_all_sym_reads(c._child_node("lhs"), sym)
+                yield from find_all_sym_reads(c._child_node("rhs"), sym)
+            elif isinstance(c._node, LoopIR.Extern):
+                yield from find_all_sym_reads(c._child_block("args"), sym)
+            elif isinstance(c._node, LoopIR.WindowExpr):
+                yield from find_all_sym_reads(c._child_block("idx"), sym)
+        elif isinstance(c._node, LoopIR.Tensor):
+            yield from find_all_sym_reads(c._child_block("hi"), sym)
+        elif isinstance(c._node, LoopIR.WindowType):
+            yield from find_all_sym_reads(c._child_node("src_type"), sym)
+            yield from find_all_sym_reads(c._child_node("as_tensor"), sym)
+            yield from find_all_sym_reads(c._child_block("idx"), sym)
+        elif isinstance(c._node, LoopIR.Interval):
+            yield from find_all_sym_reads(c._child_node("lo"), sym)
+            yield from find_all_sym_reads(c._child_node("hi"), sym)
+        elif isinstance(c._node, LoopIR.Point):
+            yield from find_all_sym_reads(c._child_node("pt"), sym)
+
+
 def _replace_reads(ir, fwd, c, sym, repl, only_replace_attrs=True):
     c = fwd(c)
     todos = []
-    for rd in match_pattern(c, f"{repr(sym)}[_]", use_sym_id=True):
-        # Need [_] to pattern match against window expressions
+    for rd in find_all_sym_reads(c, sym):
         if c_repl := repl(rd):
             todos.append((rd, c_repl))
 
     cur_fwd = lambda x: x
-    for (rd, c_repl) in todos:
+    for rd, c_repl in todos:
         rd = cur_fwd(rd)
         ir, fwd_rd = _replace_helper(rd, c_repl, only_replace_attrs)
         cur_fwd = _compose(fwd_rd, cur_fwd)
@@ -249,7 +318,7
@@ def _replace_writes( todos.append((s, c_repl)) cur_fwd = lambda x: x - for (s, c_repl) in todos: + for s, c_repl in todos: s = cur_fwd(s) ir, fwd_s = _replace_helper(s, c_repl, only_replace_attrs) cur_fwd = _compose(fwd_s, cur_fwd) @@ -366,111 +435,197 @@ def divide_expr(e, quot): # --------------------------------------------------------------------------- # # Scheduling directives +CheckResult = Union[ + tuple[LoopIR.proc, Callable[[Cursor], Cursor]], + tuple[ + LoopIR.proc, + Callable[ + [Cursor], + Cursor, + ], + Any, + ], +] + + +def do_check( + check: Callable[[Checker], CheckResult], + mode: CheckMode, +) -> CheckResult: + if mode == "both": + e_static, e_dynamic = None, None + trb_static, trb_dynamic = None, None + static_res, dynamic_res = None, None + try: + static_res = check("static") + except Exception as e: + e_static = e + trb_static = traceback.format_exc() + try: + dynamic_res = check("dynamic") + except Exception as e: + e_dynamic = e + trb_dynamic = traceback.format_exc() + if (e_static is None) != (e_dynamic is None): + assert ( + False + ), f"fuzzer should match static analysis\ntrb_static: {trb_static}\n\ntrb_dynamic: {trb_dynamic}" + elif e_static is not None: + raise e_static + else: + assert static_res is not None and dynamic_res is not None + if str(static_res[0]) != str(dynamic_res[0]): + assert ( + False + ), f"resulting object code differs between static and dynamic.\nstatic:\n{str(static_res[0])}\ndynamic:\n{str(dynamic_res[0])}" + return static_res + else: + return check(mode) + # Take a conservative approach and allow stmt reordering only when they are # writing to different buffers # TODO: Do effectcheck's check_commutes-ish thing using SMT here -def DoReorderStmt(f_cursor, s_cursor): - if f_cursor.next() != s_cursor: - raise SchedulingError( - "expected the second statement to be directly after the first" - ) - Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node) - ir, fwd = s_cursor._move(f_cursor.before()) - return ir, fwd +def DoReorderStmt(f_cursor, s_cursor, check_mode: CheckMode): + def check(checker: Checker): + if f_cursor.next() != s_cursor: + raise SchedulingError( + "expected the second statement to be directly after the first" + ) + ir, fwd = s_cursor._move(f_cursor.before()) + if checker == "dynamic": + fuzz( + f_cursor.as_block().expand(0, 1), fwd(s_cursor).as_block().expand(0, 1) + ) + else: + Check_ReorderStmts(f_cursor.get_root(), f_cursor._node, s_cursor._node) + return ir, fwd + return do_check( + check, + check_mode, + ) -def DoParallelizeLoop(loop_cursor): - return loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) +def DoParallelizeLoop(loop_cursor, check_mode: CheckMode): + def check(checker: Checker): + ir, fwd = loop_cursor._child_node("loop_mode")._replace(LoopIR.Par()) + return ir, fwd -def DoJoinLoops(loop1_c, loop2_c): - if loop1_c.next() != loop2_c: - raise SchedulingError("expected the second loop to be directly after the first") + return do_check(check, check_mode) - loop1 = loop1_c._node - loop2 = loop2_c._node - try: - Check_ExprEqvInContext(loop1_c.get_root(), loop1.hi, [loop1], loop2.lo, [loop2]) - except Exception as e: - raise SchedulingError( - f"expected the first loop upper bound {loop1.hi} to be the same as the second loop lower bound {loop2.lo}" - ) +def DoJoinLoops(loop1_c, loop2_c, check_mode: CheckMode): + def check(checker: Checker): + if loop1_c.next() != loop2_c: + raise SchedulingError( + "expected the second loop to be directly after the first" + ) - compare_ir = 
LoopIR_Compare() - if not compare_ir.match_stmts(loop1.body, loop2.body): - raise SchedulingError("expected the two loops to have identical bodies") + loop1 = loop1_c._node + loop2 = loop2_c._node - ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) - ir, fwd_del = fwd(loop2_c)._delete() + if checker == "static": + try: + Check_ExprEqvInContext( + loop1_c.get_root(), loop1.hi, [loop1], loop2.lo, [loop2] + ) + except Exception as e: + raise SchedulingError( + f"expected the first loop upper bound {loop1.hi} to be the same as the second loop lower bound {loop2.lo}" + ) - return ir, _compose(fwd_del, fwd) + compare_ir = LoopIR_Compare() + if not compare_ir.match_stmts(loop1.body, loop2.body): + raise SchedulingError("expected the two loops to have identical bodies") + ir, fwd = loop1_c._child_node("hi")._replace(loop2.hi) + ir, fwd_del = fwd(loop2_c)._delete() -def DoCutLoop(loop_c, cut_point): - s = loop_c._node + if checker == "dynamic": + fuzz(loop1_c.as_block().expand(0, 1), fwd_del(fwd(loop1_c)).as_block()) - assert isinstance(s, LoopIR.For) + return ir, _compose(fwd_del, fwd) - ir = loop_c.get_root() + return do_check(check, check_mode) - try: - Check_CompareExprs(ir, [s], cut_point, ">=", s.lo) - except SchedulingError: - raise SchedulingError(f"Expected `lo` <= `cut_point`") - try: - Check_CompareExprs(ir, [s], s.hi, ">=", cut_point) - except SchedulingError: - raise SchedulingError(f"Expected `cut_point` <= `hi`") +def DoCutLoop(loop_c, cut_point, check_mode: CheckMode): + def check(checker: Checker): + s = loop_c._node - ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) - loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] - ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) - fwd = _compose(fwd2, fwd1) + assert isinstance(s, LoopIR.For) - return ir, fwd + ir = loop_c.get_root() + if checker == "static": + try: + Check_CompareExprs(ir, [s], cut_point, ">=", s.lo) + except SchedulingError: + raise SchedulingError(f"Expected `lo` <= `cut_point`") -def DoShiftLoop(loop_c, new_lo): - s = loop_c._node + try: + Check_CompareExprs(ir, [s], s.hi, ">=", cut_point) + except SchedulingError: + raise SchedulingError(f"Expected `cut_point` <= `hi`") - assert isinstance(s, LoopIR.For) + ir, fwd1 = loop_c._child_node("hi")._replace(cut_point) + loop2 = Alpha_Rename([s.update(lo=cut_point)]).result()[0] + ir, fwd2 = fwd1(loop_c).after()._insert([loop2]) + fwd = _compose(fwd2, fwd1) - try: - Check_IsNonNegativeExpr( - loop_c.get_root(), - [s], - new_lo, - ) - except SchedulingError: - raise SchedulingError(f"Expected 0 <= `new_lo`") + if checker == "dynamic": + fuzz(loop_c.as_block(), fwd(loop_c).as_block().expand(0, 1)) - loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) - new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) + return ir, fwd - ir, fwd1 = loop_c._child_node("lo")._replace(new_lo) - ir, fwd2 = fwd1(loop_c)._child_node("hi")._replace(new_hi) - fwd12 = _compose(fwd2, fwd1) + return do_check(check, check_mode) - # all uses of the loop iteration in the second body need - # to be offset by (`lo` - `new_lo``) - loop_iter = s.iter - iter_node = LoopIR.Read(loop_iter, [], T.index, s.srcinfo) - iter_offset = LoopIR.BinOp("-", s.lo, new_lo, T.index, s.srcinfo) - new_iter = LoopIR.BinOp("+", iter_node, iter_offset, T.index, s.srcinfo) - ir, fwd = _replace_reads( - ir, - fwd12, - loop_c, - loop_iter, - lambda _: new_iter, - only_replace_attrs=False, - ) +def DoShiftLoop(loop_c, new_lo, check_mode: CheckMode): + def check(checker: Checker): + s = 
loop_c._node - return ir, fwd + assert isinstance(s, LoopIR.For) + + if checker == "static": + try: + Check_IsNonNegativeExpr( + loop_c.get_root(), + [s], + new_lo, + ) + except SchedulingError: + raise SchedulingError(f"Expected 0 <= `new_lo`") + + loop_length = LoopIR.BinOp("-", s.hi, s.lo, T.index, s.srcinfo) + new_hi = LoopIR.BinOp("+", new_lo, loop_length, T.index, s.srcinfo) + + ir, fwd1 = loop_c._child_node("lo")._replace(new_lo) + ir, fwd2 = fwd1(loop_c)._child_node("hi")._replace(new_hi) + fwd12 = _compose(fwd2, fwd1) + + # all uses of the loop iteration in the second body need + # to be offset by (`lo` - `new_lo``) + loop_iter = s.iter + iter_node = LoopIR.Read(loop_iter, [], T.index, s.srcinfo) + iter_offset = LoopIR.BinOp("-", s.lo, new_lo, T.index, s.srcinfo) + new_iter = LoopIR.BinOp("+", iter_node, iter_offset, T.index, s.srcinfo) + + ir, fwd = _replace_reads( + ir, + fwd12, + loop_c, + loop_iter, + lambda _: new_iter, + only_replace_attrs=False, + ) + + if checker == "dynamic": + fuzz(loop_c.as_block(), fwd(loop_c).as_block()) + return ir, fwd + + return do_check(check, check_mode) def DoProductLoop(outer_loop_c, new_name): @@ -661,92 +816,259 @@ def mk_inline_expr(e): def DoDivideWithRecompute( - loop_cursor, outer_hi, outer_stride: int, iter_o: str, iter_i: str + loop_cursor, + outer_hi, + outer_stride: int, + iter_o: str, + iter_i: str, + check_mode: CheckMode, ): - proc = loop_cursor.get_root() - loop = loop_cursor._node - srcinfo = loop.srcinfo + def check(checker: Checker): + proc = loop_cursor.get_root() + loop = loop_cursor._node + srcinfo = loop.srcinfo - assert isinstance(loop, LoopIR.For) - assert isinstance(outer_hi, LoopIR.expr) - Check_IsIdempotent(proc, loop.body) + assert isinstance(loop, LoopIR.For) + assert isinstance(outer_hi, LoopIR.expr) - def rd(i): - return LoopIR.Read(i, [], T.index, srcinfo) + if checker == "static": + Check_IsIdempotent(proc, loop.body) - def cnst(intval): - return LoopIR.Const(intval, T.int, srcinfo) + def rd(i): + return LoopIR.Read(i, [], T.index, srcinfo) - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) - sym_o = Sym(iter_o) - sym_i = Sym(iter_i) - x = cnst(outer_stride) + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - if ( - isinstance(outer_hi, LoopIR.BinOp) - and outer_hi.op == "/" - and isinstance(outer_hi.rhs, LoopIR.Const) - and outer_hi.rhs.val == outer_stride - ): - N_before_recompute = szop("-", outer_hi.lhs, szop("%", outer_hi.lhs, x)) - else: - N_before_recompute = szop("*", outer_hi, x) + sym_o = Sym(iter_o) + sym_i = Sym(iter_i) + x = cnst(outer_stride) - N_recompute = LoopIR.BinOp("-", loop.hi, N_before_recompute, T.index, srcinfo) - try: - Check_IsNonNegativeExpr(proc, [loop], N_recompute) - except SchedulingError: - raise SchedulingError(f"outer_hi * outer_stride exceeds loop's hi {loop.hi}") + if ( + isinstance(outer_hi, LoopIR.BinOp) + and outer_hi.op == "/" + and isinstance(outer_hi.rhs, LoopIR.Const) + and outer_hi.rhs.val == outer_stride + ): + N_before_recompute = szop("-", outer_hi.lhs, szop("%", outer_hi.lhs, x)) + else: + N_before_recompute = szop("*", outer_hi, x) - hi_o = outer_hi - hi_i = szop("+", x, N_recompute) + N_recompute = LoopIR.BinOp("-", loop.hi, N_before_recompute, T.index, srcinfo) + if checker == "static": + try: + Check_IsNonNegativeExpr(proc, [loop], N_recompute) + except SchedulingError: + raise SchedulingError( + f"outer_hi * outer_stride exceeds loop's hi 
{loop.hi}" + ) - # turn current loop into outer loop - ir, fwd = loop_cursor._child_node("iter")._replace(sym_o) - ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(hi_o) - fwd = _compose(fwd_repl, fwd) + hi_o = outer_hi + hi_i = szop("+", x, N_recompute) - # wrap body in inner loop - def inner_wrapper(body): - return LoopIR.For( - sym_i, - LoopIR.Const(0, T.index, srcinfo), - hi_i, - body, - LoopIR.Seq(), - srcinfo, - ) + # turn current loop into outer loop + ir, fwd = loop_cursor._child_node("iter")._replace(sym_o) + ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(hi_o) + fwd = _compose(fwd_repl, fwd) - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) + # wrap body in inner loop + def inner_wrapper(body): + return LoopIR.For( + sym_i, + LoopIR.Const(0, T.index, srcinfo), + hi_i, + body, + LoopIR.Seq(), + srcinfo, + ) - # replace the iteration variable in the body - def mk_iter(_): - return szop("+", szop("*", rd(sym_o), x), rd(sym_i)) + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) - ir, fwd = _replace_reads( - ir, - fwd, - loop_cursor, - loop.iter, - mk_iter, - only_replace_attrs=False, - ) + # replace the iteration variable in the body + def mk_iter(_): + return szop("+", szop("*", rd(sym_o), x), rd(sym_i)) - return ir, fwd + ir, fwd = _replace_reads( + ir, + fwd, + loop_cursor, + loop.iter, + mk_iter, + only_replace_attrs=False, + ) + if checker == "dynamic": + fuzz(loop_cursor.as_block(), fwd(loop_cursor).as_block()) + + return ir, fwd + + return do_check(check, check_mode) def DoDivideLoop( - loop_cursor, quot, outer_iter, inner_iter, tail="guard", perfect=False + loop_cursor, + quot, + outer_iter, + inner_iter, + check_mode: CheckMode, + tail="guard", + perfect=False, ): + def check(checker: Checker): + loop = loop_cursor._node + N = loop.hi + outer_i = Sym(outer_iter) + inner_i = Sym(inner_iter) + srcinfo = loop.srcinfo + tail_strategy = "perfect" if perfect else tail + + if not is_const_zero(loop.lo): + raise SchedulingError( + f"expected the lower bound of the loop to be zero, got {loop.lo}." 
+ ) + + def substitute(srcinfo): + cnst = lambda x: LoopIR.Const(x, T.int, srcinfo) + rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) + op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) + + return op("+", op("*", cnst(quot), rd(outer_i)), rd(inner_i)) + + # short-hands for sanity + def boolop(op, lhs, rhs, typ): + return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) + + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + + def cnst(intval): + return LoopIR.Const(intval, T.int, srcinfo) + + def rd(i): + return LoopIR.Read(i, [], T.index, srcinfo) + + def ceildiv(lhs, rhs): + assert isinstance(rhs, LoopIR.Const) and rhs.val > 0 + rhs_1 = cnst(rhs.val - 1) + return szop("/", szop("+", lhs, rhs_1), rhs) + + # determine hi and lo loop bounds + inner_hi = cnst(quot) + if tail_strategy in ["guard"]: + outer_hi = ceildiv(N, inner_hi) + elif tail_strategy in ["cut", "cut_and_guard"]: + outer_hi = szop("/", N, inner_hi) # floor div + elif tail_strategy == "perfect": + if checker == "static": + ir = loop_cursor.get_root() + loop = loop_cursor._node + Check_IsDivisible(ir, [loop], N, quot) + outer_hi = divide_expr(N, quot) + else: + assert False, f"bad tail strategy: {tail_strategy}" + + # turn current loop into outer loop + ir, fwd = loop_cursor._child_node("iter")._replace(outer_i) + ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(outer_hi) + fwd = _compose(fwd_repl, fwd) + + # wrap body in a guard + if tail_strategy == "guard": + idx_sub = substitute(srcinfo) + + def guard_wrapper(body): + cond = boolop("<", idx_sub, N, T.bool) + return LoopIR.If(cond, body, [], srcinfo) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(guard_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # wrap body in inner loop + def inner_wrapper(body): + return LoopIR.For( + inner_i, + LoopIR.Const(0, T.index, srcinfo), + inner_hi, + body, + loop.loop_mode, + srcinfo, + ) + + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + + # replace the iteration variable in the body + def mk_main_iter(c): + return substitute(c._node.srcinfo) + + ir, fwd = _replace_reads( + ir, + fwd, + loop_cursor, + loop.iter, + mk_main_iter, + only_replace_attrs=False, + ) + + # add the tail case + if tail_strategy in ["cut", "cut_and_guard"]: + cut_i = Sym(inner_iter) + Ntail = szop("%", N, inner_hi) + + # in the tail loop we want the iteration variable to + # be mapped instead to (Ncut*Q + cut_i) + cut_tail_sub = szop("+", rd(cut_i), szop("*", outer_hi, inner_hi)) + + cut_body = Alpha_Rename(loop.body).result() + env = {loop.iter: cut_tail_sub} + cut_body = SubstArgs(cut_body, env).result() + + cut_s = LoopIR.For( + cut_i, + LoopIR.Const(0, T.index, srcinfo), + Ntail, + cut_body, + loop.loop_mode, + srcinfo, + ) + if tail_strategy == "cut_and_guard": + cond = boolop(">", Ntail, LoopIR.Const(0, T.int, srcinfo), T.bool) + cut_s = LoopIR.If(cond, [cut_s], [], srcinfo) + + ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) + fwd = _compose(fwd_ins, fwd) + if checker == "dynamic": + fuzz( + loop_cursor.as_block(), + ( + fwd(loop_cursor).as_block().expand(0, 1) + if tail_strategy in ["cut", "cut_and_guard"] + else fwd(loop_cursor).as_block() + ), + ) + + return ir, fwd + + return do_check(check, check_mode) + + +def DoDivideLoopMin( + loop_cursor, + quot, + outer_iter, + inner_iter, + check_mode: CheckMode, +): + if check_mode != "dynamic": + raise SchedulingError("cannot use min tail strategy without chexo") loop = loop_cursor._node N = 
loop.hi outer_i = Sym(outer_iter) inner_i = Sym(inner_iter) srcinfo = loop.srcinfo - tail_strategy = "perfect" if perfect else tail if not is_const_zero(loop.lo): raise SchedulingError( @@ -754,72 +1076,68 @@ def DoDivideLoop( ) def substitute(srcinfo): - cnst = lambda x: LoopIR.Const(x, T.int, srcinfo) rd = lambda x: LoopIR.Read(x, [], T.index, srcinfo) op = lambda op, lhs, rhs: LoopIR.BinOp(op, lhs, rhs, T.index, srcinfo) - return op("+", op("*", cnst(quot), rd(outer_i)), rd(inner_i)) + return op("+", op("*", quot, rd(outer_i)), rd(inner_i)) # short-hands for sanity def boolop(op, lhs, rhs, typ): return LoopIR.BinOp(op, lhs, rhs, typ, srcinfo) - def szop(op, lhs, rhs): - return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) - def cnst(intval): return LoopIR.Const(intval, T.int, srcinfo) + def szop(op, lhs, rhs): + return LoopIR.BinOp(op, lhs, rhs, lhs.type, srcinfo) + def rd(i): return LoopIR.Read(i, [], T.index, srcinfo) def ceildiv(lhs, rhs): - assert isinstance(rhs, LoopIR.Const) and rhs.val > 0 - rhs_1 = cnst(rhs.val - 1) + rhs_1 = LoopIR.BinOp("-", rhs, cnst(1), rhs.type, srcinfo) return szop("/", szop("+", lhs, rhs_1), rhs) # determine hi and lo loop bounds - inner_hi = cnst(quot) - if tail_strategy == "guard": - outer_hi = ceildiv(N, inner_hi) - elif tail_strategy in ["cut", "cut_and_guard"]: - outer_hi = szop("/", N, inner_hi) # floor div - elif tail_strategy == "perfect": - ir = loop_cursor.get_root() - loop = loop_cursor._node - Check_IsDivisible(ir, [loop], N, quot) - outer_hi = divide_expr(N, quot) - else: - assert False, f"bad tail strategy: {tail_strategy}" + inner_hi = quot + outer_hi = ceildiv(N, inner_hi) # turn current loop into outer loop ir, fwd = loop_cursor._child_node("iter")._replace(outer_i) ir, fwd_repl = fwd(loop_cursor)._child_node("hi")._replace(outer_hi) fwd = _compose(fwd_repl, fwd) - # wrap body in a guard - if tail_strategy == "guard": - idx_sub = substitute(srcinfo) - - def guard_wrapper(body): - cond = boolop("<", idx_sub, N, T.bool) - return LoopIR.If(cond, body, [], srcinfo) - - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(guard_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - - # wrap body in inner loop - def inner_wrapper(body): + def inner_wrapper_min(body): return LoopIR.For( inner_i, LoopIR.Const(0, T.index, srcinfo), - inner_hi, + LoopIR.Extern( + intmin, + [ + inner_hi, + LoopIR.BinOp( + "-", + N, + LoopIR.BinOp( + "*", + quot, + LoopIR.Read(outer_i, [], T.index, srcinfo), + T.index, + srcinfo, + ), + N.type, + srcinfo, + ), + ], + T.size, + srcinfo, + ), body, loop.loop_mode, srcinfo, ) - ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper, "body") + ir, fwd_wrap = fwd(loop_cursor).body()._wrap(inner_wrapper_min, "body") fwd = _compose(fwd_wrap, fwd) # replace the iteration variable in the body @@ -835,34 +1153,7 @@ def mk_main_iter(c): only_replace_attrs=False, ) - # add the tail case - if tail_strategy in ["cut", "cut_and_guard"]: - cut_i = Sym(inner_iter) - Ntail = szop("%", N, inner_hi) - - # in the tail loop we want the iteration variable to - # be mapped instead to (Ncut*Q + cut_i) - cut_tail_sub = szop("+", rd(cut_i), szop("*", outer_hi, inner_hi)) - - cut_body = Alpha_Rename(loop.body).result() - env = {loop.iter: cut_tail_sub} - cut_body = SubstArgs(cut_body, env).result() - - cut_s = LoopIR.For( - cut_i, - LoopIR.Const(0, T.index, srcinfo), - Ntail, - cut_body, - loop.loop_mode, - srcinfo, - ) - if tail_strategy == "cut_and_guard": - cond = boolop(">", Ntail, LoopIR.Const(0, T.int, srcinfo), T.bool) - cut_s = LoopIR.If(cond, 
[cut_s], [], srcinfo) - - ir, fwd_ins = fwd(loop_cursor).after()._insert([cut_s]) - fwd = _compose(fwd_ins, fwd) - + fuzz(loop_cursor.as_block(), fwd(loop_cursor).as_block()) return ir, fwd @@ -1142,20 +1433,31 @@ def mk_write(c): return ir, fwd -def DoConfigWrite(stmt_cursor, config, field, expr, before=False): - assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) - s = stmt_cursor._node +def DoConfigWrite( + stmt_cursor, config, field, expr, check_mode: CheckMode, before=False +): + def check(checker: Checker): + assert isinstance(expr, (LoopIR.Read, LoopIR.StrideExpr, LoopIR.Const)) + s = stmt_cursor._node - cw_s = LoopIR.WriteConfig(config, field, expr, s.srcinfo) + cw_s = LoopIR.WriteConfig(config, field, expr, s.srcinfo) - if before: - ir, fwd = stmt_cursor.before()._insert([cw_s]) - else: - ir, fwd = stmt_cursor.after()._insert([cw_s]) + if before: + ir, fwd = stmt_cursor.before()._insert([cw_s]) + else: + ir, fwd = stmt_cursor.after()._insert([cw_s]) - cfg = Check_DeleteConfigWrite(ir, [cw_s]) + if checker == "static": + cfg = Check_DeleteConfigWrite(ir, [cw_s]) + else: + cfg = None + fuzz( + stmt_cursor.as_block(), + fwd(stmt_cursor).as_block().expand(*((1, 0) if before else (0, 1))), + ) + return ir, fwd, cfg - return ir, fwd, cfg + return do_check(check, check_mode) # --------------------------------------------------------------------------- # @@ -1163,28 +1465,38 @@ def DoConfigWrite(stmt_cursor, config, field, expr, before=False): # Bind Expression scheduling directive -def DoBindConfig(config, field, expr_cursor): - e = expr_cursor._node - assert isinstance(e, LoopIR.Read) +def DoBindConfig(config, field, expr_cursor, check_mode): + def check(checker: Checker): + e = expr_cursor._node + if checker == "static": + assert isinstance(e, LoopIR.Read) - c = expr_cursor - while not isinstance(c._node, LoopIR.stmt): - c = c.parent() + c = expr_cursor + while not isinstance(c._node, LoopIR.stmt): + c = c.parent() - cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) - ir, fwd = c.before()._insert([cfg_write_s]) + cfg_write_s = LoopIR.WriteConfig(config, field, e, e.srcinfo) + ir, fwd = c.before()._insert([cfg_write_s]) - mod_cfg = Check_DeleteConfigWrite(ir, [cfg_write_s]) + if checker == "static": + mod_cfg = Check_DeleteConfigWrite(ir, [cfg_write_s]) + else: + mod_cfg = None - cfg_f_type = config.lookup_type(field) - cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) - if isinstance(expr_cursor.parent()._node, LoopIR.Call): - cfg_read_e = [cfg_read_e] - ir, fwd_repl = fwd(expr_cursor)._replace(cfg_read_e) - fwd = _compose(fwd_repl, fwd) + cfg_f_type = config.lookup_type(field) + cfg_read_e = LoopIR.ReadConfig(config, field, cfg_f_type, e.srcinfo) + if isinstance(expr_cursor.parent()._node, LoopIR.Call): + cfg_read_e = [cfg_read_e] + ir, fwd_repl = fwd(expr_cursor)._replace(cfg_read_e) + fwd = _compose(fwd_repl, fwd) - Check_Aliasing(ir) - return ir, fwd, mod_cfg + if checker == "static": + Check_Aliasing(ir) + else: + fuzz(c.as_block(), fwd(c).as_block().expand(1, 0)) + return ir, fwd, mod_cfg + + return do_check(check, check_mode) def DoCommuteExpr(expr_cursors): @@ -1234,212 +1546,247 @@ def match_parent(c1, c2): return c1, c2 -def DoRewriteExpr(expr_cursor, new_expr): - proc = expr_cursor.get_root() - s = get_enclosing_stmt_cursor(expr_cursor)._node - Check_ExprEqvInContext(proc, expr_cursor._node, [s], new_expr, [s]) - return expr_cursor._replace(new_expr) - - -def DoBindExpr(new_name, expr_cursors): - assert expr_cursors - - expr 
= expr_cursors[0]._node - assert isinstance(expr, LoopIR.expr) - assert expr.type.is_numeric() - - expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] - # TODO: dirty hack. need real CSE-equality (i.e. modulo srcinfo) - expr_cursors = [c for c in expr_cursors if str(c._node) == str(expr)] +def DoRewriteExpr(expr_cursor, new_expr, check_mode): + def check(checker: Checker): + ir, fwd = expr_cursor._replace(new_expr) + if checker == "static": + proc = expr_cursor.get_root() + s = get_enclosing_stmt_cursor(expr_cursor)._node + Check_ExprEqvInContext(proc, expr_cursor._node, [s], new_expr, [s]) + else: + stmt_cursor = get_enclosing_stmt_cursor(expr_cursor) + fuzz(stmt_cursor.as_block(), fwd(stmt_cursor).as_block()) + return ir, fwd - init_s = get_enclosing_stmt_cursor(expr_cursors[0]) - if len(expr_cursors) > 1: - # TODO: Currently assume expr cursors is sorted in order - init_s, _ = match_parent(init_s, expr_cursors[-1]) + return do_check(check, check_mode) + + +def DoBindExpr(new_name, expr_cursors, check_mode: CheckMode): + def check(checker: Checker): + assert expr_cursors + + expr = expr_cursors[0]._node + assert isinstance(expr, LoopIR.expr) + et = expr.type if expr.type.is_numeric() else T.i32 + if checker == "static": + assert expr.type.is_numeric() + + expr_reads = [name for (name, typ) in get_reads_of_expr(expr)] + # TODO: dirty hack. need real CSE-equality (i.e. modulo srcinfo) + expr_cursors_eq = [c for c in expr_cursors if str(c._node) == str(expr)] + + init_s = get_enclosing_stmt_cursor(expr_cursors_eq[0]) + if len(expr_cursors_eq) > 1: + # TODO: Currently assume expr cursors is sorted in order + init_s, _ = match_parent(init_s, expr_cursors_eq[-1]) + + new_name_sym = Sym(new_name) + alloc_s = LoopIR.Alloc(new_name_sym, et.basetype(), DRAM, expr.srcinfo) + assign_s = LoopIR.Assign(new_name_sym, et.basetype(), [], expr, expr.srcinfo) + ir, fwd = init_s.before()._insert([alloc_s, assign_s]) + + new_read = LoopIR.Read(new_name_sym, [], expr.type, expr.srcinfo) + first_write_c = None + for c in get_rest_of_block(init_s, inclusive=True): + if checker == "static": + for block in match_pattern(c, "_ = _") + match_pattern(c, "_ += _"): + assert len(block) == 1 + sc = block[0] + if sc._node.name in expr_reads: + first_write_c = sc + break + + if first_write_c and isinstance(c._node, (LoopIR.For, LoopIR.If)): + # Potentially unsafe to partially bind, err on side of caution for now + break + + while expr_cursors_eq and c.is_ancestor_of(expr_cursors_eq[0]): + ir, fwd_repl = _replace_helper( + fwd(expr_cursors_eq[0]), new_read, only_replace_attrs=False + ) + fwd = _compose(fwd_repl, fwd) + expr_cursors_eq.pop(0) - new_name = Sym(new_name) - alloc_s = LoopIR.Alloc(new_name, expr.type.basetype(), DRAM, expr.srcinfo) - assign_s = LoopIR.Assign(new_name, expr.type.basetype(), [], expr, expr.srcinfo) - ir, fwd = init_s.before()._insert([alloc_s, assign_s]) - - new_read = LoopIR.Read(new_name, [], expr.type, expr.srcinfo) - first_write_c = None - for c in get_rest_of_block(init_s, inclusive=True): - for block in match_pattern(c, "_ = _") + match_pattern(c, "_ += _"): - assert len(block) == 1 - sc = block[0] - if sc._node.name in expr_reads: - first_write_c = sc + if first_write_c: break - if first_write_c and isinstance(c._node, (LoopIR.For, LoopIR.If)): - # Potentially unsafe to partially bind, err on side of caution for now - break + if len(expr_cursors_eq) > 0: + raise SchedulingError("Unsafe to bind all of the provided exprs.") - while expr_cursors and 
c.is_ancestor_of(expr_cursors[0]): - ir, fwd_repl = _replace_helper( - fwd(expr_cursors[0]), new_read, only_replace_attrs=False + if checker == "static": + Check_Aliasing(ir) + else: + fuzz( + init_s.as_block().expand(0, None), + fwd(init_s).as_block().expand(1, None), ) - fwd = _compose(fwd_repl, fwd) - expr_cursors.pop(0) - - if first_write_c: - break - - if len(expr_cursors) > 0: - raise SchedulingError("Unsafe to bind all of the provided exprs.") - - Check_Aliasing(ir) - return ir, fwd - + return ir, fwd -def DoLiftScope(inner_c): - inner_s = inner_c._node - assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) - target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" + return do_check(check, check_mode) + + +def DoLiftScope(inner_c, check_mode: CheckMode): + def check(checker: Checker): + inner_s = inner_c._node + assert isinstance(inner_s, (LoopIR.If, LoopIR.For)) + target_type = "if statement" if isinstance(inner_s, LoopIR.If) else "for loop" + + outer_c = inner_c.parent() + if outer_c.root() == outer_c: + raise SchedulingError("Cannot lift scope of top-level statement") + outer_s = outer_c._node + + ir, fwd = inner_c.get_root(), lambda x: x + + if isinstance(outer_s, LoopIR.If): + + def if_wrapper(body, insert_orelse=False): + src = outer_s.srcinfo + # this is needed because _replace expects a non-zero length block + orelse = [LoopIR.Pass(src)] if insert_orelse else [] + return LoopIR.If(outer_s.cond, body, orelse, src) + + def orelse_wrapper(orelse): + src = outer_s.srcinfo + body = [LoopIR.Pass(src)] + return LoopIR.If(outer_s.cond, body, orelse, src) + + if isinstance(inner_s, LoopIR.If): + if inner_s in outer_s.body: + # if INNER: + # if OUTER: if OUTER: A + # if INNER: A else: C + # else: B ~> else: + # else: C if OUTER: B + # else: C + if len(outer_s.body) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) - outer_c = inner_c.parent() - if outer_c.root() == outer_c: - raise SchedulingError("Cannot lift scope of top-level statement") - outer_s = outer_c._node + blk_c = outer_s.orelse + wrapper = lambda body: if_wrapper(body, insert_orelse=bool(blk_c)) - ir, fwd = inner_c.get_root(), lambda x: x + ir, fwd = inner_c.body()._wrap(wrapper, "body") + if blk_c: + ir, fwd_repl = fwd(inner_c).body()[0].orelse()._replace(blk_c) + fwd = _compose(fwd_repl, fwd) - if isinstance(outer_s, LoopIR.If): + if inner_s.orelse: + ir, fwd_wrap = fwd(inner_c).orelse()._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + if blk_c: + ir, fwd_repl = ( + fwd(inner_c).orelse()[0].orelse()._replace(blk_c) + ) + fwd = _compose(fwd_repl, fwd) + else: + # if INNER: + # if OUTER: A if OUTER: A + # else: else: B + # if INNER: B ~> else: + # else: C if OUTER: A + # else: C + assert inner_s in outer_s.orelse + if len(outer_s.orelse) > 1: + raise SchedulingError( + f"expected {target_type} to be directly nested in parent" + ) - def if_wrapper(body, insert_orelse=False): - src = outer_s.srcinfo - # this is needed because _replace expects a non-zero length block - orelse = [LoopIR.Pass(src)] if insert_orelse else [] - return LoopIR.If(outer_s.cond, body, orelse, src) + blk_a = outer_s.body - def orelse_wrapper(orelse): - src = outer_s.srcinfo - body = [LoopIR.Pass(src)] - return LoopIR.If(outer_s.cond, body, orelse, src) + ir, fwd = inner_c.body()._wrap(orelse_wrapper, "orelse") + ir, fwd_repl = fwd(inner_c).body()[0].body()._replace(blk_a) + fwd = _compose(fwd_repl, fwd) - if isinstance(inner_s, LoopIR.If): - if inner_s in outer_s.body: - # if 
INNER: - # if OUTER: if OUTER: A - # if INNER: A else: C - # else: B ~> else: - # else: C if OUTER: B - # else: C + if inner_s.orelse: + ir, fwd_wrap = ( + fwd(inner_c).orelse()._wrap(orelse_wrapper, "orelse") + ) + fwd = _compose(fwd_wrap, fwd) + ir, fwd_repl = fwd(inner_c).orelse()[0].body()._replace(blk_a) + fwd = _compose(fwd_repl, fwd) + elif isinstance(inner_s, LoopIR.For): + # if OUTER: for INNER in _: + # for INNER in _: A ~> if OUTER: A if len(outer_s.body) > 1: raise SchedulingError( f"expected {target_type} to be directly nested in parent" ) - blk_c = outer_s.orelse - wrapper = lambda body: if_wrapper(body, insert_orelse=bool(blk_c)) - - ir, fwd = inner_c.body()._wrap(wrapper, "body") - if blk_c: - ir, fwd_repl = fwd(inner_c).body()[0].orelse()._replace(blk_c) - fwd = _compose(fwd_repl, fwd) - - if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - if blk_c: - ir, fwd_repl = fwd(inner_c).orelse()[0].orelse()._replace(blk_c) - fwd = _compose(fwd_repl, fwd) - else: - # if INNER: - # if OUTER: A if OUTER: A - # else: else: B - # if INNER: B ~> else: - # else: C if OUTER: A - # else: C - assert inner_s in outer_s.orelse - if len(outer_s.orelse) > 1: + if outer_s.orelse: raise SchedulingError( - f"expected {target_type} to be directly nested in parent" + "cannot lift for loop when if has an orelse clause" ) - blk_a = outer_s.body + ir, fwd = inner_c.body()._move(inner_c.after()) + ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_move = fwd(outer_c)._move(fwd(inner_c).body()[0].after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(inner_c).body()[0]._delete() + fwd = _compose(fwd_del, fwd) - ir, fwd = inner_c.body()._wrap(orelse_wrapper, "orelse") - ir, fwd_repl = fwd(inner_c).body()[0].body()._replace(blk_a) - fwd = _compose(fwd_repl, fwd) + if checker == "dynamic": + fuzz(outer_c.as_block(), fwd(inner_c).as_block()) + return ir, fwd - if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(orelse_wrapper, "orelse") - fwd = _compose(fwd_wrap, fwd) - ir, fwd_repl = fwd(inner_c).orelse()[0].body()._replace(blk_a) - fwd = _compose(fwd_repl, fwd) - elif isinstance(inner_s, LoopIR.For): - # if OUTER: for INNER in _: - # for INNER in _: A ~> if OUTER: A + elif isinstance(outer_s, LoopIR.For): if len(outer_s.body) > 1: raise SchedulingError( f"expected {target_type} to be directly nested in parent" ) - if outer_s.orelse: - raise SchedulingError( - "cannot lift for loop when if has an orelse clause" - ) + def loop_wrapper(body): + return outer_s.update(body=body) - ir, fwd = inner_c.body()._move(inner_c.after()) - ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_move = fwd(outer_c)._move(fwd(inner_c).body()[0].after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(inner_c).body()[0]._delete() - fwd = _compose(fwd_del, fwd) + if isinstance(inner_s, LoopIR.If): + # for OUTER in _: if INNER: + # if INNER: A ~> for OUTER in _: A + # else: B else: + # for OUTER in _: B + if outer_s.iter in _FV(inner_s.cond): + raise SchedulingError("if statement depends on iteration variable") - return ir, fwd + ir, fwd = inner_c.body()._wrap(loop_wrapper, "body") - elif isinstance(outer_s, LoopIR.For): - if len(outer_s.body) > 1: - raise SchedulingError( - f"expected {target_type} to be directly nested in parent" - ) - - def loop_wrapper(body): - return outer_s.update(body=body) - - if isinstance(inner_s, LoopIR.If): - # for OUTER in 
_: if INNER: - # if INNER: A ~> for OUTER in _: A - # else: B else: - # for OUTER in _: B - if outer_s.iter in _FV(inner_s.cond): - raise SchedulingError("if statement depends on iteration variable") + if inner_s.orelse: + ir, fwd_wrap = fwd(inner_c).orelse()._wrap(loop_wrapper, "body") + fwd = _compose(fwd_wrap, fwd) + elif isinstance(inner_s, LoopIR.For): + # for OUTER in _: for INNER in _: + # for INNER in _: A ~> for OUTER in _: A + reads = get_reads_of_expr(inner_s.lo) + get_reads_of_expr(inner_s.hi) + if outer_s.iter in [name for name, _ in reads]: + raise SchedulingError( + "inner loop's lo or hi depends on outer loop's iteration variable" + ) - ir, fwd = inner_c.body()._wrap(loop_wrapper, "body") + if checker == "static": + Check_ReorderLoops(inner_c.get_root(), outer_s) + body = inner_c.body() + ir, fwd = inner_c._move(outer_c.after()) + ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) + fwd = _compose(fwd_move, fwd) + ir, fwd_move = fwd(body)._move(fwd(outer_c).body().after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(outer_c).body()[0]._delete() + fwd = _compose(fwd_del, fwd) + if checker == "dynamic": + fuzz(outer_c.as_block(), fwd(inner_c).as_block()) + return ir, fwd - if inner_s.orelse: - ir, fwd_wrap = fwd(inner_c).orelse()._wrap(loop_wrapper, "body") - fwd = _compose(fwd_wrap, fwd) - elif isinstance(inner_s, LoopIR.For): - # for OUTER in _: for INNER in _: - # for INNER in _: A ~> for OUTER in _: A - reads = get_reads_of_expr(inner_s.lo) + get_reads_of_expr(inner_s.hi) - if outer_s.iter in [name for name, _ in reads]: - raise SchedulingError( - "inner loop's lo or hi depends on outer loop's iteration variable" - ) + ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) + fwd = _compose(fwd_move, fwd) + ir, fwd_del = fwd(outer_c)._delete() + fwd = _compose(fwd_del, fwd) - Check_ReorderLoops(inner_c.get_root(), outer_s) - body = inner_c.body() - ir, fwd = inner_c._move(outer_c.after()) - ir, fwd_move = fwd(outer_c)._move(fwd(body).before()) - fwd = _compose(fwd_move, fwd) - ir, fwd_move = fwd(body)._move(fwd(outer_c).body().after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(outer_c).body()[0]._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd + if checker == "dynamic": + fuzz(outer_c.as_block(), fwd(inner_c).as_block()) - ir, fwd_move = fwd(inner_c)._move(fwd(outer_c).after()) - fwd = _compose(fwd_move, fwd) - ir, fwd_del = fwd(outer_c)._delete() - fwd = _compose(fwd_del, fwd) + return ir, fwd - return ir, fwd + return do_check(check, check_mode) def DoLiftConstant(assign_c, loop_c): @@ -1571,115 +1918,137 @@ def reduces_have_same_constant(s1, s2): return ir, fwd -def DoExpandDim(alloc_cursor, alloc_dim, indexing): - alloc_s = alloc_cursor._node - assert isinstance(alloc_s, LoopIR.Alloc) - assert isinstance(alloc_dim, LoopIR.expr) - assert isinstance(indexing, LoopIR.expr) +def DoExpandDim(alloc_cursor, alloc_dim, indexing, check_mode: CheckMode): + def check(checker: Checker): + alloc_s = alloc_cursor._node + assert isinstance(alloc_s, LoopIR.Alloc) + assert isinstance(alloc_dim, LoopIR.expr) + assert isinstance(indexing, LoopIR.expr) - Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], alloc_dim) + if checker == "static": + Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], alloc_dim) - old_typ = alloc_s.type - new_rngs = [alloc_dim] - if isinstance(old_typ, T.Tensor): - new_rngs += old_typ.shape() - basetyp = old_typ.basetype() - new_typ = T.Tensor(new_rngs, False, basetyp) - new_alloc = alloc_s.update(type=new_typ) + old_typ 
= alloc_s.type + new_rngs = [alloc_dim] + if isinstance(old_typ, T.Tensor): + new_rngs += old_typ.shape() + basetyp = old_typ.basetype() + new_typ = T.Tensor(new_rngs, False, basetyp) + new_alloc = alloc_s.update(type=new_typ) - ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) + ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) - def mk_read(c): - rd = c._node + def mk_read(c): + rd = c._node - # TODO: do I need to worry about Builtins too? - if isinstance(c.parent()._node, (LoopIR.Call)) and not rd.idx: - raise SchedulingError( - "TODO: Please Contact the developers to fix (i.e. add) " - "support for passing windows to scalar arguments" - ) + # TODO: do I need to worry about Builtins too? + if isinstance(c.parent()._node, (LoopIR.Call)) and not rd.idx: + raise SchedulingError( + "TODO: Please Contact the developers to fix (i.e. add) " + "support for passing windows to scalar arguments" + ) - if isinstance(rd, LoopIR.Read): - return {"idx": [indexing] + rd.idx} - elif isinstance(rd, LoopIR.WindowExpr): - return {"idx": [LoopIR.Point(indexing, rd.srcinfo)] + rd.idx} - else: - raise NotImplementedError( - f"Did not implement {type(rd)}. This may be a bug." - ) + if isinstance(rd, LoopIR.Read): + return {"idx": [indexing] + rd.idx} + elif isinstance(rd, LoopIR.WindowExpr): + return {"idx": [LoopIR.Point(indexing, rd.srcinfo)] + rd.idx} + else: + raise NotImplementedError( + f"Did not implement {type(rd)}. This may be a bug." + ) - def mk_write(c): - s = c._node - return {"idx": [indexing] + s.idx} + def mk_write(c): + s = c._node + return {"idx": [indexing] + s.idx} - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) - after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] - Check_Bounds(ir, new_alloc, after_alloc) + after_alloc = [c._node for c in get_rest_of_block(fwd(alloc_cursor))] - return ir, fwd + if checker == "static": + Check_Bounds(ir, new_alloc, after_alloc) + else: + fuzz(get_rest_of_block(alloc_cursor), get_rest_of_block(fwd(alloc_cursor))) + return ir, fwd + return do_check(check, check_mode) -def DoResizeDim(alloc_cursor, dim_idx: int, size: LoopIR.expr, offset: LoopIR.expr): - alloc_s = alloc_cursor._node - alloc_name = alloc_s.name - assert isinstance(alloc_s, LoopIR.Alloc) - assert isinstance(alloc_s.type, T.Tensor) - Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], size) +def DoResizeDim( + alloc_cursor, + dim_idx: int, + size: LoopIR.expr, + offset: LoopIR.expr, + check_mode: CheckMode, +): + def check(checker: Checker): + alloc_s = alloc_cursor._node + alloc_name = alloc_s.name + assert isinstance(alloc_s, LoopIR.Alloc) + assert isinstance(alloc_s.type, T.Tensor) + + if checker == "static": + Check_IsPositiveExpr(alloc_cursor.get_root(), [alloc_s], size) + + ir, fwd = ( + alloc_cursor._child_node("type") + ._child_block("hi")[dim_idx] + ._replace([size]) + ) - ir, fwd = ( - alloc_cursor._child_node("type")._child_block("hi")[dim_idx]._replace([size]) - ) + def mk_read(c): + rd = c._node - def mk_read(c): - rd = c._node + def mk_binop(e): + return LoopIR.BinOp("-", e, offset, offset.type, rd.srcinfo) - def mk_binop(e): - return LoopIR.BinOp("-", e, offset, offset.type, rd.srcinfo) + new_idx = rd.idx.copy() + if isinstance(rd, LoopIR.Read): + 
new_idx[dim_idx] = mk_binop(rd.idx[dim_idx]) + return {"idx": new_idx} - new_idx = rd.idx.copy() - if isinstance(rd, LoopIR.Read): - new_idx[dim_idx] = mk_binop(rd.idx[dim_idx]) - return {"idx": new_idx} + elif isinstance(rd, LoopIR.WindowExpr): + if isinstance(rd.idx[dim_idx], LoopIR.Point): + new_idx[dim_idx] = LoopIR.Point( + mk_binop(rd.idx[dim_idx].pt), rd.srcinfo + ) + else: + new_idx[dim_idx] = LoopIR.Interval( + mk_binop(rd.idx[dim_idx].lo), + mk_binop(rd.idx[dim_idx].hi), + rd.srcinfo, + ) - elif isinstance(rd, LoopIR.WindowExpr): - if isinstance(rd.idx[dim_idx], LoopIR.Point): - new_idx[dim_idx] = LoopIR.Point( - mk_binop(rd.idx[dim_idx].pt), rd.srcinfo - ) + return {"idx": new_idx} else: - new_idx[dim_idx] = LoopIR.Interval( - mk_binop(rd.idx[dim_idx].lo), - mk_binop(rd.idx[dim_idx].hi), - rd.srcinfo, + raise NotImplementedError( + f"Did not implement {type(rd)}. This may be a bug." ) - return {"idx": new_idx} - else: - raise NotImplementedError( - f"Did not implement {type(rd)}. This may be a bug." + def mk_write(c): + s = c._node + new_idx = s.idx.copy() + new_idx[dim_idx] = LoopIR.BinOp( + "-", s.idx[dim_idx], offset, offset.type, s.srcinfo ) + return {"idx": new_idx} - def mk_write(c): - s = c._node - new_idx = s.idx.copy() - new_idx[dim_idx] = LoopIR.BinOp( - "-", s.idx[dim_idx], offset, offset.type, s.srcinfo - ) - return {"idx": new_idx} + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + new_alloc_cursor = fwd(alloc_cursor) + after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] - alloc_cursor = fwd(alloc_cursor) - after_alloc = [c._node for c in get_rest_of_block(alloc_cursor)] - Check_Bounds(ir, alloc_cursor._node, after_alloc) + if checker == "static": + Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + else: + fuzz(get_rest_of_block(alloc_cursor), get_rest_of_block(new_alloc_cursor)) + return ir, fwd - return ir, fwd + return do_check(check, check_mode) def DoRearrangeDim(decl_cursor, permute_vector): @@ -1761,63 +2130,69 @@ def mk_stride_expr(c): return ir, fwd -def DoDivideDim(alloc_cursor, dim_idx, quotient): - alloc_s = alloc_cursor._node - alloc_sym = alloc_s.name +def DoDivideDim(alloc_cursor, dim_idx, quotient, check_mode: CheckMode): + def check(checker: Checker): + alloc_s = alloc_cursor._node + alloc_sym = alloc_s.name - assert isinstance(alloc_s, LoopIR.Alloc) - assert isinstance(dim_idx, int) - assert isinstance(quotient, int) + assert isinstance(alloc_s, LoopIR.Alloc) + assert isinstance(dim_idx, int) + assert isinstance(quotient, int) - old_typ = alloc_s.type - old_shp = old_typ.shape() - dim = old_shp[dim_idx] - Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) - numer = divide_expr(dim, quotient) - new_shp = ( - old_shp[:dim_idx] - + [ - numer, - LoopIR.Const(quotient, T.int, dim.srcinfo), - ] - + old_shp[dim_idx + 1 :] - ) - new_typ = T.Tensor(new_shp, False, old_typ.basetype()) + old_typ = alloc_s.type + old_shp = old_typ.shape() + dim = old_shp[dim_idx] - ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) + if checker == "static": + Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) + numer = divide_expr(dim, quotient) + new_shp = ( + old_shp[:dim_idx] + + [ + numer, + LoopIR.Const(quotient, T.int, 
dim.srcinfo), + ] + + old_shp[dim_idx + 1 :] + ) + new_typ = T.Tensor(new_shp, False, old_typ.basetype()) - def remap_idx(idx): - orig_i = idx[dim_idx] - srcinfo = orig_i.srcinfo - quot = LoopIR.Const(quotient, T.int, srcinfo) - hi = LoopIR.BinOp("/", orig_i, quot, orig_i.type, srcinfo) - lo = LoopIR.BinOp("%", orig_i, quot, orig_i.type, srcinfo) - return idx[:dim_idx] + [hi, lo] + idx[dim_idx + 1 :] + ir, fwd = alloc_cursor._child_node("type")._replace(new_typ) - def mk_read(c): - rd = c._node + def remap_idx(idx): + orig_i = idx[dim_idx] + srcinfo = orig_i.srcinfo + quot = LoopIR.Const(quotient, T.int, srcinfo) + hi = LoopIR.BinOp("/", orig_i, quot, orig_i.type, srcinfo) + lo = LoopIR.BinOp("%", orig_i, quot, orig_i.type, srcinfo) + return idx[:dim_idx] + [hi, lo] + idx[dim_idx + 1 :] - if isinstance(rd, LoopIR.Read) and not rd.idx: - raise SchedulingError( - f"Cannot divide {alloc_sym} because buffer is passed as an argument" - ) - elif isinstance(rd, LoopIR.WindowExpr): - raise SchedulingError( - f"Cannot divide {alloc_sym} because the buffer is windowed later on" - ) + def mk_read(c): + rd = c._node - return {"idx": remap_idx(rd.idx)} + if isinstance(rd, LoopIR.Read) and not rd.idx: + raise SchedulingError( + f"Cannot divide {alloc_sym} because buffer is passed as an argument" + ) + elif isinstance(rd, LoopIR.WindowExpr): + raise SchedulingError( + f"Cannot divide {alloc_sym} because the buffer is windowed later on" + ) - def mk_write(c): - s = c._node - return {"idx": remap_idx(s.idx)} + return {"idx": remap_idx(rd.idx)} - # TODO: add better iteration primitive - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + def mk_write(c): + s = c._node + return {"idx": remap_idx(s.idx)} + + # TODO: add better iteration primitive + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_s.name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_s.name, mk_write) + if checker == "dynamic": + fuzz(get_rest_of_block(alloc_cursor), get_rest_of_block(fwd(alloc_cursor))) + return ir, fwd - return ir, fwd + return do_check(check, check_mode) def DoMultiplyDim(alloc_cursor, hi_idx, lo_idx): @@ -2298,146 +2673,163 @@ def _stmt(s): return all(_stmt(s) for s in stmts) -def DoRemoveLoop(loop, unsafe_disable_check): - s = loop._node +def DoRemoveLoop(loop, unsafe_disable_check, check_mode): + def check(checker: Checker): + s = loop._node - # Check if we can remove the loop. Conditions are: - # 1. Body does not depend on the loop iteration variable - if s.iter in _FV(s.body): - raise SchedulingError( - f"Cannot remove loop, {s.iter} is not " "free in the loop body." - ) + # Check if we can remove the loop. Conditions are: + # 1. Body does not depend on the loop iteration variable + if s.iter in _FV(s.body): + raise SchedulingError( + f"Cannot remove loop, {s.iter} is not " "free in the loop body." + ) - # 2. Body is idempotent - if not unsafe_disable_check: - Check_IsIdempotent(loop.get_root(), [s]) + # 2. Body is idempotent + if not unsafe_disable_check and checker == "static": + Check_IsIdempotent(loop.get_root(), [s]) - # 3. The loop runs at least once; - # If not, then place a guard around the statement - ir, fwd = loop.get_root(), lambda x: x - try: - Check_IsPositiveExpr(loop.get_root(), [s], s.hi) - except SchedulingError: - cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo) + # 3. 
The loop runs at least once;
+        #    If not, then place a guard around the statement
+        ir, fwd_move = loop.body()._move(loop.after())
+        ir, fwd_del = fwd_move(loop)._delete()
+        fwd = _compose(fwd_del, fwd_move)
+        try:
+            if checker == "static":
+                Check_IsPositiveExpr(loop.get_root(), [s], s.hi)
+            else:
+                fuzz(loop.as_block(), fwd(loop.body()))
+        except SchedulingError:
+            cond = LoopIR.BinOp(">", s.hi, s.lo, T.bool, s.srcinfo)

-    def wrapper(body):
-        return LoopIR.If(cond, body, [], s.srcinfo)
+            def wrapper(body):
+                return LoopIR.If(cond, body, [], s.srcinfo)

-    ir, fwd = loop.body()._wrap(wrapper, "body")
+            ir, fwd_wrap = fwd(loop.body())._wrap(wrapper, "body")
+            fwd = _compose(fwd_wrap, fwd)
+            if checker == "dynamic":
+                fuzz(loop.as_block(), fwd(loop.body()).parent().as_block())

-    ir, fwd_move = fwd(loop).body()._move(fwd(loop).after())
-    fwd = _compose(fwd_move, fwd)
-    ir, fwd_del = fwd(loop)._delete()
-    fwd = _compose(fwd_del, fwd)
+        return ir, fwd

-    return ir, fwd
+    return do_check(check, check_mode)
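
Every directive in this file now follows the same two-phase shape: the rewrite is built inside a `check(checker)` closure and handed to `do_check`, which selects the static verifier, the dynamic fuzzer, or both. The dispatcher is defined elsewhere in this file; the following is only a minimal sketch of the assumed dispatch, with `CheckMode` being one of `"static"`, `"dynamic"`, or `"both"`:

    def do_check_sketch(check, check_mode):
        # Hypothetical rendering of do_check: "both" (the chexo_debug path)
        # runs the closure under each checker and keeps the dynamic result.
        if check_mode == "both":
            check("static")
            return check("dynamic")
        # "static" and "dynamic" run the closure exactly once.
        return check(check_mode)


# This is same as original FissionAfter, except that
# this does not remove loop. We have separate remove_loop
# operator for that purpose.
-def DoFissionAfterSimple(stmt_cursor, n_lifts, unsafe_disable_checks):
-    tgt_stmt = stmt_cursor._node
-    assert isinstance(tgt_stmt, LoopIR.stmt)
-    assert is_pos_int(n_lifts)
-
-    ir, fwd = stmt_cursor.get_root(), lambda x: x
+def DoFissionAfterSimple(stmt_cursor, n_lifts_start, unsafe_disable_checks, check_mode):
+    def check(checker: Checker):
+        n_lifts = n_lifts_start
+        tgt_stmt = stmt_cursor._node
+        assert isinstance(tgt_stmt, LoopIR.stmt)
+        assert is_pos_int(n_lifts)

-    def alloc_check(pre, post):
-        if not _is_alloc_free(pre, post):
-            pre_allocs = {s.name for s in pre if isinstance(s, LoopIR.Alloc)}
-            post_FV = _FV(post)
-            for nm in pre_allocs:
-                if nm in post_FV:
-                    raise SchedulingError(
-                        f"Will not fission here, because "
-                        f"doing so will hide the allocation "
-                        f"of {nm} from a later use site."
+        ir, fwd = stmt_cursor.get_root(), lambda x: x
+
+        def alloc_check(pre, post):
+            if not _is_alloc_free(pre, post):
+                pre_allocs = {s.name for s in pre if isinstance(s, LoopIR.Alloc)}
+                post_FV = _FV(post)
+                for nm in pre_allocs:
+                    if nm in post_FV:
+                        raise SchedulingError(
+                            f"Will not fission here, because "
+                            f"doing so will hide the allocation "
+                            f"of {nm} from a later use site."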
+ ) - cur_c = stmt_cursor - while n_lifts > 0: - n_lifts -= 1 + cur_c = stmt_cursor + while n_lifts > 0: + n_lifts -= 1 - idx = cur_c.get_index() + 1 - par_c = cur_c.parent() - par_s = par_c._node + idx = cur_c.get_index() + 1 + par_c = cur_c.parent() + par_s = par_c._node - if isinstance(par_s, LoopIR.For): - pre_c = par_c.body()[:idx] - post_c = par_c.body()[idx:] - elif isinstance(par_s, LoopIR.If): - if cur_c._node in par_s.body: + if isinstance(par_s, LoopIR.For): pre_c = par_c.body()[:idx] post_c = par_c.body()[idx:] + elif isinstance(par_s, LoopIR.If): + if cur_c._node in par_s.body: + pre_c = par_c.body()[:idx] + post_c = par_c.body()[idx:] + else: + pre_c = par_c.orelse()[:idx] + post_c = par_c.orelse()[idx:] else: - pre_c = par_c.orelse()[:idx] - post_c = par_c.orelse()[idx:] - else: - raise SchedulingError("Can only lift past a for loop or an if statement") + raise SchedulingError( + "Can only lift past a for loop or an if statement" + ) - pre = [s._node for s in pre_c] - post = [s._node for s in post_c] + pre = [s._node for s in pre_c] + post = [s._node for s in post_c] - if not (pre and post): - continue + if not (pre and post): + continue - alloc_check(pre, post) + alloc_check(pre, post) - if isinstance(par_s, LoopIR.For): - # we must check whether the two parts of the - # fission can commute appropriately - no_loop_var_pre = par_s.iter not in _FV(pre) - if not unsafe_disable_checks: - Check_FissionLoop(ir, par_s, pre, post, no_loop_var_pre) + if isinstance(par_s, LoopIR.For): + # we must check whether the two parts of the + # fission can commute appropriately + no_loop_var_pre = par_s.iter not in _FV(pre) + if not unsafe_disable_checks and checker == "static": + Check_FissionLoop(ir, par_s, pre, post, no_loop_var_pre) - # we can skip the loop iteration if the - # body doesn't depend on the loop - # and the body is idempotent + # we can skip the loop iteration if the + # body doesn't depend on the loop + # and the body is idempotent - def wrapper(body): - return par_s.update(body=body) + def wrapper(body): + return par_s.update(body=body) - ir, fwd_wrap = post_c._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) + ir, fwd_wrap = post_c._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) - post_c = fwd_wrap(par_c).body()[-1] - ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) - fwd = _compose(fwd_move, fwd) + post_c = fwd_wrap(par_c).body()[-1] + ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) + fwd = _compose(fwd_move, fwd) - cur_c = fwd_move(fwd_wrap(par_c)) - elif isinstance(par_s, LoopIR.If): - if cur_c._node in par_s.body: + cur_c = fwd_move(fwd_wrap(par_c)) + elif isinstance(par_s, LoopIR.If): + if cur_c._node in par_s.body: - def wrapper(body): - return par_s.update(body=body, orelse=[]) + def wrapper(body): + return par_s.update(body=body, orelse=[]) - ir, fwd_wrap = pre_c._wrap(wrapper, "body") - fwd = _compose(fwd_wrap, fwd) + ir, fwd_wrap = pre_c._wrap(wrapper, "body") + fwd = _compose(fwd_wrap, fwd) - pre_c = fwd_wrap(par_c).body()[0] - ir, fwd_move = pre_c._move(fwd_wrap(par_c).before()) - fwd = _compose(fwd_move, fwd) + pre_c = fwd_wrap(par_c).body()[0] + ir, fwd_move = pre_c._move(fwd_wrap(par_c).before()) + fwd = _compose(fwd_move, fwd) - cur_c = fwd_move(fwd_wrap(par_c)).prev() - else: - assert cur_c._node in par_s.orelse + cur_c = fwd_move(fwd_wrap(par_c)).prev() + else: + assert cur_c._node in par_s.orelse - def wrapper(orelse): - return par_s.update( - body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse - ) + def wrapper(orelse): + return 
par_s.update( + body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse + ) - ir, fwd_wrap = post_c._wrap(wrapper, "orelse") - fwd = _compose(fwd_wrap, fwd) + ir, fwd_wrap = post_c._wrap(wrapper, "orelse") + fwd = _compose(fwd_wrap, fwd) - post_c = fwd_wrap(par_c).orelse()[-1] - ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) - fwd = _compose(fwd_move, fwd) + post_c = fwd_wrap(par_c).orelse()[-1] + ir, fwd_move = post_c._move(fwd_wrap(par_c).after()) + fwd = _compose(fwd_move, fwd) - cur_c = fwd_move(fwd_wrap(par_c)) + cur_c = fwd_move(fwd_wrap(par_c)) - return ir, fwd + if checker == "dynamic": + scope_cursor = stmt_cursor + for _ in range(n_lifts_start): + scope_cursor = scope_cursor.parent() + if n_lifts_start > 0: + fuzz(scope_cursor.as_block(), cur_c.as_block().expand(0, 1)) + return ir, fwd + + return do_check(check, check_mode) # TODO: Deprecate this with the one above @@ -2623,100 +3015,119 @@ def are_allocs_used_after_block(): return ir, fwd -def DoFuseLoop(f_cursor, s_cursor, unsafe_disable_check=False): - proc = f_cursor.get_root() +def DoFuseLoop(f_cursor, s_cursor, check_mode: CheckMode, unsafe_disable_check=False): + def check(checker: Checker): + proc = f_cursor.get_root() - if f_cursor.next() != s_cursor: - raise SchedulingError( - f"expected the two loops to be fused to come one right after the other. However, the statement after the first loop is:\n{f_cursor.next()._node}\n, not the provided second loop:\n {s_cursor._node}" - ) + if f_cursor.next() != s_cursor: + raise SchedulingError( + f"expected the two loops to be fused to come one right after the other. However, the statement after the first loop is:\n{f_cursor.next()._node}\n, not the provided second loop:\n {s_cursor._node}" + ) - # check if the loop bounds are equivalent - loop1 = f_cursor._node - loop2 = s_cursor._node - Check_ExprEqvInContext(proc, loop1.hi, [loop1], loop2.hi, [loop2]) + # check if the loop bounds are equivalent + loop1 = f_cursor._node + loop2 = s_cursor._node + Check_ExprEqvInContext(proc, loop1.hi, [loop1], loop2.hi, [loop2]) - def mk_read(e): - return LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) + def mk_read(e): + return LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) - ir, fwd = proc, lambda x: x - ir, fwd = _replace_reads( - ir, fwd, s_cursor, loop2.iter, mk_read, only_replace_attrs=False - ) - ir, fwd_move = fwd(s_cursor).body()._move(fwd(f_cursor).body()[-1].after()) - fwd = _compose(fwd_move, fwd) - ir, fwdDel = fwd(s_cursor)._delete() - fwd = _compose(fwdDel, fwd) + ir, fwd = proc, lambda x: x + ir, fwd = _replace_reads( + ir, fwd, s_cursor, loop2.iter, mk_read, only_replace_attrs=False + ) + ir, fwd_move = fwd(s_cursor).body()._move(fwd(f_cursor).body()[-1].after()) + fwd = _compose(fwd_move, fwd) + ir, fwdDel = fwd(s_cursor)._delete() + fwd = _compose(fwdDel, fwd) - if not unsafe_disable_check: - x = LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) - y = loop2.iter - body1 = loop1.body - body2 = SubstArgs(loop2.body, {y: x}).result() - loop = fwd(f_cursor)._node - Check_FissionLoop(ir, loop, body1, body2) + if not unsafe_disable_check and checker == "static": + x = LoopIR.Read(loop1.iter, [], T.index, loop1.srcinfo) + y = loop2.iter + body1 = loop1.body + body2 = SubstArgs(loop2.body, {y: x}).result() + loop = fwd(f_cursor)._node + Check_FissionLoop(ir, loop, body1, body2) - return ir, fwd + if checker == "dynamic": + fuzz(f_cursor.as_block().expand(delta_lo=0, delta_hi=1), fwd) + return ir, fwd -def DoFuseIf(f_cursor, s_cursor): - proc = f_cursor.get_root() - if 
f_cursor.next() != s_cursor: - raise SchedulingError( - "expected the two if statements to be fused to come one right after the other" - ) + return do_check(check, check_mode) - if1 = f_cursor._node - if2 = s_cursor._node - Check_ExprEqvInContext(proc, if1.cond, [if1], if2.cond, [if2]) - cond = if1.cond - body1 = if1.body - body2 = if2.body - orelse1 = if1.orelse - orelse2 = if2.orelse - ifstmt = LoopIR.If(cond, body1 + body2, orelse1 + orelse2, if1.srcinfo) +def DoFuseIf(f_cursor, s_cursor, check_mode: CheckMode): + def check(checker: Checker): + proc = f_cursor.get_root() + if f_cursor.next() != s_cursor: + raise SchedulingError( + "expected the two if statements to be fused to come one right after the other" + ) - ir, fwd = s_cursor.body()._move(f_cursor.body()[-1].after()) - if f_cursor.orelse(): - ir, fwd_move = fwd(s_cursor).orelse()._move(fwd(f_cursor).orelse()[-1].after()) - fwd = _compose(fwd_move, fwd) - else: - ir, fwd_repl = fwd(f_cursor).orelse()._replace(orelse1 + orelse2) - fwd = _compose(fwd_repl, fwd) - ir, fwd_del = fwd(s_cursor)._delete() - fwd = _compose(fwd_del, fwd) - return ir, fwd + if1 = f_cursor._node + if2 = s_cursor._node + if checker == "static": + Check_ExprEqvInContext(proc, if1.cond, [if1], if2.cond, [if2]) + + cond = if1.cond + body1 = if1.body + body2 = if2.body + orelse1 = if1.orelse + orelse2 = if2.orelse + ifstmt = LoopIR.If(cond, body1 + body2, orelse1 + orelse2, if1.srcinfo) + + ir, fwd = s_cursor.body()._move(f_cursor.body()[-1].after()) + if f_cursor.orelse(): + ir, fwd_move = ( + fwd(s_cursor).orelse()._move(fwd(f_cursor).orelse()[-1].after()) + ) + fwd = _compose(fwd_move, fwd) + else: + ir, fwd_repl = fwd(f_cursor).orelse()._replace(orelse1 + orelse2) + fwd = _compose(fwd_repl, fwd) + ir, fwd_del = fwd(s_cursor)._delete() + fwd = _compose(fwd_del, fwd) + if checker == "dynamic": + fuzz(f_cursor.as_block().expand(0, 1), fwd(f_cursor).as_block()) + return ir, fwd + return do_check(check, check_mode) -def DoAddLoop(stmt_cursor, var, hi, guard, unsafe_disable_check): - proc = stmt_cursor.get_root() - s = stmt_cursor._node - if not unsafe_disable_check: - Check_IsIdempotent(proc, [s]) - Check_IsPositiveExpr(proc, [s], hi) +def DoAddLoop(stmt_cursor, var, hi, guard, unsafe_disable_check, check_mode: CheckMode): + def check(checker: Checker): + proc = stmt_cursor.get_root() + s = stmt_cursor._node - sym = Sym(var) + if not unsafe_disable_check and checker == "static": + Check_IsIdempotent(proc, [s]) + Check_IsPositiveExpr(proc, [s], hi) - def wrapper(body): - if guard: - rdsym = LoopIR.Read(sym, [], T.index, s.srcinfo) - zero = LoopIR.Const(0, T.int, s.srcinfo) - cond = LoopIR.BinOp("==", rdsym, zero, T.bool, s.srcinfo) - body = [LoopIR.If(cond, body, [], s.srcinfo)] + sym = Sym(var) - return LoopIR.For( - sym, - LoopIR.Const(0, T.index, s.srcinfo), - hi, - body, - LoopIR.Seq(), - s.srcinfo, - ) + def wrapper(body): + if guard: + rdsym = LoopIR.Read(sym, [], T.index, s.srcinfo) + zero = LoopIR.Const(0, T.int, s.srcinfo) + cond = LoopIR.BinOp("==", rdsym, zero, T.bool, s.srcinfo) + body = [LoopIR.If(cond, body, [], s.srcinfo)] + + return LoopIR.For( + sym, + LoopIR.Const(0, T.index, s.srcinfo), + hi, + body, + LoopIR.Seq(), + s.srcinfo, + ) - ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") - return ir, fwd + ir, fwd = stmt_cursor.as_block()._wrap(wrapper, "body") + if checker == "dynamic": + fuzz(stmt_cursor.as_block(), fwd(stmt_cursor).parent().as_block()) + return ir, fwd + + return do_check(check, check_mode) # 
--------------------------------------------------------------------------- #
@@ -2778,10 +3189,59 @@ def err_handler(_, msg):
     return ir, fwd


-def DoDeleteConfig(proc_cursor, config_cursor):
-    eq_mod_config = Check_DeleteConfigWrite(proc_cursor._node, [config_cursor._node])
-    p, fwd = config_cursor._delete()
-    return p, fwd, eq_mod_config
+def DoDeleteConfig(proc_cursor, config_cursor, check_mode: CheckMode):
+    def check(checker: Checker):
+        if checker == "static":
+            eq_mod_config = Check_DeleteConfigWrite(
+                proc_cursor._node, [config_cursor._node]
+            )
+        else:
+            eq_mod_config = None
+        p, fwd = config_cursor._delete()
+        if checker == "dynamic":
+            scope = config_cursor.parent()
+            if scope.depth() == 0:
+                scope = scope._child_block("body")
+            else:
+                scope = scope.as_block()
+            fuzz(scope, fwd(scope))
+        return p, fwd, eq_mod_config
+
+    return do_check(check, check_mode)
+
+
+def DoDeleteStmt(stmt_cursor, check_mode: CheckMode):
+    def check(checker: Checker):
+        if checker == "static":
+            assert False, "DoDeleteStmt can only be verified dynamically (use chexo)"
+        else:
+            scope = stmt_cursor.parent()
+            if scope.depth() == 0:
+                scope = scope._child_block("body")
+            else:
+                scope = scope.as_block()
+        p, fwd = stmt_cursor._delete()
+        fuzz(scope, fwd(scope))
+        return p, fwd
+
+    return do_check(check, check_mode)
+
+
+def DoInsertStmt(gap_cursor, new_stmt, check_mode: CheckMode):
+    def check(checker: Checker):
+        if checker == "static":
+            assert False, "DoInsertStmt can only be verified dynamically (use chexo)"
+        else:
+            scope = gap_cursor.parent()
+            if scope.depth() == 0:
+                scope = scope._child_block("body")
+            else:
+                scope = scope.as_block()
+        p, fwd = gap_cursor._insert([new_stmt])
+        fuzz(scope, fwd(scope))
+        return p, fwd
+
+    return do_check(check, check_mode)
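
The two chexo-only directives above have no static counterpart: deleting or inserting an arbitrary statement is justified only by the dynamic equivalence check. All three directives (including `DoDeleteConfig`) compute the fuzzing scope with the same inlined idiom: the parent statement's block when one exists, otherwise the whole proc body. As a helper it would read roughly as follows (a hypothetical refactor; the diff deliberately inlines it at each site):

    def _fuzz_scope(cursor):
        # The smallest enclosing block whose behavior the fuzzer must preserve.
        scope = cursor.parent()
        if scope.depth() == 0:
            # Top-level statement: compare the entire proc body.
            return scope._child_block("body")
        return scope.as_block()


def DoDeletePass(proc):
@@ -3587,57 +4047,78 @@ def map_s(self, sc):
         raise NotImplementedError(f"bad case {type(s)}")


-def DoEliminateIfDeadBranch(if_cursor):
-    if_stmt = if_cursor._node
+def DoEliminateIfDeadBranch(if_cursor, check_mode: CheckMode):
+    def check(checker: Checker):
+        if_stmt = if_cursor._node
+        original_ir = if_cursor.get_root()

-    assert isinstance(if_stmt, LoopIR.If)
+        assert isinstance(if_stmt, LoopIR.If)

-    ir, fwd = if_cursor.get_root(), lambda x: x
-
-    try:
-        cond_node = LoopIR.Const(True, T.bool, if_stmt.srcinfo)
-        Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node)
-        cond = True
-    except SchedulingError:
+        try:
+            body = if_cursor.body()
+            ir, fwd = body._move(if_cursor.after())
+            ir, fwd_del = fwd(if_cursor)._delete()
+            fwd = _compose(fwd_del, fwd)
+            if checker == "static":
+                cond_node = LoopIR.Const(True, T.bool, if_stmt.srcinfo)
+                Check_ExprEqvInContext(original_ir, if_stmt.cond, [if_stmt], cond_node)
+            else:
+                fuzz(if_cursor.as_block(), fwd(if_cursor.body()))
+        except SchedulingError:
             try:
-                cond_node = LoopIR.Const(False, T.bool, if_stmt.srcinfo)
-                Check_ExprEqvInContext(ir, if_stmt.cond, [if_stmt], cond_node)
-                cond = False
+                if checker == "static":
+                    cond_node = LoopIR.Const(False, T.bool, if_stmt.srcinfo)
+                    Check_ExprEqvInContext(
+                        original_ir, if_stmt.cond, [if_stmt], cond_node
+                    )
+                else:
+                    body = if_cursor.orelse()
+                    ir, fwd = body._move(if_cursor.after())
+                    ir, fwd_del = fwd(if_cursor)._delete()
+                    fwd = _compose(fwd_del, fwd)
+                    fuzz(if_cursor.as_block(), fwd(if_cursor.orelse()))
             except SchedulingError:
                 raise SchedulingError("If condition isn't always True or always False")

-    body = if_cursor.body() if cond else if_cursor.orelse()
-    ir, fwd = body._move(if_cursor.after())
-    ir, 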
fwd_del = fwd(if_cursor)._delete()
-    fwd = _compose(fwd_del, fwd)
+        return ir, fwd

-    return ir, fwd
+    # DoEliminateIfDeadBranch is still broken under chexo, so force static checking.
+    return do_check(check, "static")


-def DoEliminateDeadLoop(loop_cursor):
-    loop_stmt = loop_cursor._node
+def DoEliminateDeadLoop(loop_cursor, check_mode: CheckMode):
+    def check(checker: Checker):
+        loop_stmt = loop_cursor._node

-    assert isinstance(loop_stmt, LoopIR.For)
+        assert isinstance(loop_stmt, LoopIR.For)

-    ir, fwd = loop_cursor.get_root(), lambda x: x
+        if checker == "static":
+            ir = loop_cursor.get_root()
+            try:
+                Check_CompareExprs(ir, [loop_stmt], loop_stmt.lo, ">=", loop_stmt.hi)
+            except SchedulingError:
+                raise SchedulingError("Loop condition isn't always False")

-    try:
-        Check_CompareExprs(ir, [loop_stmt], loop_stmt.lo, ">=", loop_stmt.hi)
-    except SchedulingError:
-        raise SchedulingError("Loop condition isn't always False")
+        ir, fwd_del = loop_cursor._delete()
+        if checker == "dynamic":
+            scope = loop_cursor.parent()
+            if scope.depth() == 0:
+                scope = scope._child_block("body")
+            else:
+                scope = scope.as_block()
+            fuzz(scope, fwd_del(scope))

-    ir, fwd_del = loop_cursor._delete()
+        return ir, fwd_del

-    return ir, fwd_del
+    return do_check(check, check_mode)


-def DoEliminateDeadCode(stmt_cursor):
+def DoEliminateDeadCode(stmt_cursor, check_mode: CheckMode):
     stmt = stmt_cursor._node
     if isinstance(stmt, LoopIR.If):
-        return DoEliminateIfDeadBranch(stmt_cursor)
+        return DoEliminateIfDeadBranch(stmt_cursor, check_mode)
     elif isinstance(stmt, LoopIR.For):
-        return DoEliminateDeadLoop(stmt_cursor)
+        return DoEliminateDeadLoop(stmt_cursor, check_mode)
     else:
         assert False, f"Unsupported statement type {type(stmt)}"

@@ -3652,33 +4133,40 @@ def DoDeleteBuffer(buf_cursor):
     return buf_cursor._delete()


-def DoReuseBuffer(buf_cursor, rep_cursor):
-    assert isinstance(buf_cursor._node, LoopIR.Alloc)
-    assert isinstance(rep_cursor._node, LoopIR.Alloc)
-    assert buf_cursor._node.type == rep_cursor._node.type
+def DoReuseBuffer(buf_cursor, rep_cursor, check_mode):
+    def check(checker: Checker):
+        assert isinstance(buf_cursor._node, LoopIR.Alloc)
+        assert isinstance(rep_cursor._node, LoopIR.Alloc)
+        assert buf_cursor._node.type == rep_cursor._node.type

-    buf_name = buf_cursor._node.name
-    buf_dims = len(buf_cursor._node.type.shape())
-    rep_name = rep_cursor._node.name
-    first_assn = True
+        buf_name = buf_cursor._node.name
+        buf_dims = len(buf_cursor._node.type.shape())
+        rep_name = rep_cursor._node.name
+        first_assn = True

-    ir, fwd = rep_cursor._delete()
+        ir, fwd = rep_cursor._delete()

-    def mk_read(c):
-        return {"name": buf_name}
+        def mk_read(c):
+            return {"name": buf_name}

-    def mk_write(c):
-        nonlocal first_assn
-        if first_assn:
-            first_assn = False
-            Check_IsDeadAfter(buf_cursor.get_root(), [c._node], buf_name, buf_dims)
-        return {"name": buf_name}
+        def mk_write(c):
+            nonlocal first_assn
+            if first_assn:
+                first_assn = False
+                if checker == "static":
+                    Check_IsDeadAfter(
+                        buf_cursor.get_root(), [c._node], buf_name, buf_dims
+                    )
+            return {"name": buf_name}

-    for c in get_rest_of_block(rep_cursor):
-        ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read)
-        ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write)
+        for c in get_rest_of_block(rep_cursor):
+            ir, fwd = _replace_reads(ir, fwd, c, rep_name, mk_read)
+            ir, fwd = _replace_writes(ir, fwd, c, rep_name, mk_write)
+        if checker == "dynamic":
+            fuzz(get_rest_of_block(rep_cursor), fwd(get_rest_of_block(rep_cursor)))
+        return ir, fwd

-    return ir, fwd
+    return do_check(check, check_mode)
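
Every dynamic branch above bottoms out in `fuzz(old, new)`. Its contract, as these call sites use it (DoRemoveLoop even wraps it in `try/except SchedulingError`), is to execute both program fragments on random inputs and raise `SchedulingError` on any observable divergence. A sketch of that contract only, where `run_block` and `random_env` are assumed stand-ins for chexo's LoopIR interpreter and input generator, not the real API:

    def fuzz_sketch(old_block, new_block, run_block, random_env, trials=32):
        # Compare the observable effects of the two fragments on random inputs.
        for _ in range(trials):
            env = random_env(old_block)
            if run_block(old_block, dict(env)) != run_block(new_block, dict(env)):
                raise SchedulingError("rewrite is not observationally equivalent")


def 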
index_range_analysis_wrapper(expr: LoopIR.expr) -> IndexRange: @@ -3814,337 +4302,401 @@ def do_e(self, e): super().do_e(e) -def DoFoldBuffer(alloc_cursor, dim_idx, new_size): - alloc_name = alloc_cursor._node.name +def DoFoldBuffer(alloc_cursor, dim_idx, new_size, check_mode: CheckMode): + def check(checker: Checker): + alloc_name = alloc_cursor._node.name - buffer_check = CheckFoldBuffer(alloc_name, dim_idx, new_size) - buffer_check.do_stmts([c._node for c in get_rest_of_block(alloc_cursor)]) + if checker == "static": + buffer_check = CheckFoldBuffer(alloc_name, dim_idx, new_size) + buffer_check.do_stmts([c._node for c in get_rest_of_block(alloc_cursor)]) - size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) - ir, fwd = ( - alloc_cursor._child_node("type") - ._child_block("hi")[dim_idx] - ._replace([size_expr]) - ) + size_expr = LoopIR.Const(new_size, T.index, alloc_cursor._node.srcinfo) + ir, fwd = ( + alloc_cursor._child_node("type") + ._child_block("hi")[dim_idx] + ._replace([size_expr]) + ) - def make_index_mod(e): - return LoopIR.BinOp("%", e, size_expr, T.index, e.srcinfo) + def make_index_mod(e): + return LoopIR.BinOp("%", e, size_expr, T.index, e.srcinfo) - def mk_read(c): - rd = c._node - new_idx = rd.idx.copy() - if isinstance(rd, LoopIR.Read): - new_idx[dim_idx] = make_index_mod(rd.idx[dim_idx]) - return {"idx": new_idx} + def mk_read(c): + rd = c._node + new_idx = rd.idx.copy() + if isinstance(rd, LoopIR.Read): + new_idx[dim_idx] = make_index_mod(rd.idx[dim_idx]) + return {"idx": new_idx} - elif isinstance(rd, LoopIR.WindowExpr): - if isinstance(rd.idx[dim_idx], LoopIR.Point): - new_idx[dim_idx] = LoopIR.Point( - make_index_mod(rd.idx[dim_idx].pt), rd.srcinfo - ) + elif isinstance(rd, LoopIR.WindowExpr): + if isinstance(rd.idx[dim_idx], LoopIR.Point): + new_idx[dim_idx] = LoopIR.Point( + make_index_mod(rd.idx[dim_idx].pt), rd.srcinfo + ) + else: + # TODO: see if check_bounds catches the case where lo, hi spans a multiple + # of size, which would break the buffer folding + new_idx[dim_idx] = LoopIR.Interval( + make_index_mod(rd.idx[dim_idx].lo), + make_index_mod(rd.idx[dim_idx].hi), + rd.srcinfo, + ) + + return {"idx": new_idx} else: - # TODO: see if check_bounds catches the case where lo, hi spans a multiple - # of size, which would break the buffer folding - new_idx[dim_idx] = LoopIR.Interval( - make_index_mod(rd.idx[dim_idx].lo), - make_index_mod(rd.idx[dim_idx].hi), - rd.srcinfo, - ) + raise NotImplementedError(f"Did not implement {type(rd)}.") + def mk_write(c): + s = c._node + new_idx = s.idx.copy() + new_idx[dim_idx] = make_index_mod(s.idx[dim_idx]) return {"idx": new_idx} - else: - raise NotImplementedError(f"Did not implement {type(rd)}.") - def mk_write(c): - s = c._node - new_idx = s.idx.copy() - new_idx[dim_idx] = make_index_mod(s.idx[dim_idx]) - return {"idx": new_idx} + for c in get_rest_of_block(alloc_cursor): + ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) + ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) - for c in get_rest_of_block(alloc_cursor): - ir, fwd = _replace_reads(ir, fwd, c, alloc_name, mk_read) - ir, fwd = _replace_writes(ir, fwd, c, alloc_name, mk_write) + new_alloc_cursor = fwd(alloc_cursor) + after_alloc = [c._node for c in get_rest_of_block(new_alloc_cursor)] - alloc_cursor = fwd(alloc_cursor) - after_alloc = [c._node for c in get_rest_of_block(alloc_cursor)] - Check_Bounds(ir, alloc_cursor._node, after_alloc) + if checker == "static": + Check_Bounds(ir, new_alloc_cursor._node, after_alloc) + else: + 
fuzz(get_rest_of_block(alloc_cursor), fwd(get_rest_of_block(alloc_cursor))) + return ir, fwd - return ir, fwd + return do_check(check, check_mode) -def DoStageMem(block_cursor, buf_name, w_exprs, new_name, use_accum_zero=False): +def DoStageMem( + block_cursor, buf_name_str, w_exprs, new_name, check_mode, use_accum_zero=False +): new_name = Sym(new_name) - def get_typ_mem(): - syms_env = extract_env(block_cursor[0]) - for name, typ, mem in syms_env: - if str(name) == buf_name: - return name, typ, mem - assert False, "Must find the symbol in env" - - buf_name, buf_typ, mem = get_typ_mem() - buf_typ = buf_typ if not isinstance(buf_typ, T.Window) else buf_typ.as_tensor + def check(checker: Checker): + def get_typ_mem(): + syms_env = extract_env(block_cursor[0]) + for name, typ, mem in syms_env: + if str(name) == buf_name_str: + return name, typ, mem + assert False, "Must find the symbol in env" - if len(w_exprs) != len(buf_typ.shape()): - raise SchedulingError( - f"expected windowing of '{buf_name}' " - f"to have {len(buf_typ.shape())} indices, " - f"but only got {len(w_exprs)}" - ) + buf_name, buf_typ, mem = get_typ_mem() + buf_typ = buf_typ if not isinstance(buf_typ, T.Window) else buf_typ.as_tensor - shape = [ - LoopIR.BinOp("-", w[1], w[0], T.index, w[0].srcinfo) - for w in w_exprs - if isinstance(w, tuple) - ] - if all(isinstance(w, LoopIR.expr) for w in w_exprs): - new_typ = buf_typ.basetype() - else: - new_typ = T.Tensor(shape, False, buf_typ.basetype()) + if len(w_exprs) != len(buf_typ.shape()): + raise SchedulingError( + f"expected windowing of '{buf_name}' " + f"to have {len(buf_typ.shape())} indices, " + f"but only got {len(w_exprs)}" + ) - def rewrite_idx(idx): - assert len(idx) == len(w_exprs) - return [ - LoopIR.BinOp("-", i, w[0], T.index, i.srcinfo) - for i, w in zip(idx, w_exprs) + shape = [ + LoopIR.BinOp("-", w[1], w[0], T.index, w[0].srcinfo) + for w in w_exprs if isinstance(w, tuple) ] + if all(isinstance(w, LoopIR.expr) for w in w_exprs): + new_typ = buf_typ.basetype() + else: + new_typ = T.Tensor(shape, False, buf_typ.basetype()) + + def rewrite_idx(idx): + assert len(idx) == len(w_exprs) + return [ + LoopIR.BinOp("-", i, w[0], T.index, i.srcinfo) + for i, w in zip(idx, w_exprs) + if isinstance(w, tuple) + ] - def rewrite_win(w_idx): - assert len(w_idx) == len(w_exprs) - - def off_w(w, off): - if isinstance(w, LoopIR.Interval): - lo = LoopIR.BinOp("-", w.lo, off, T.index, w.srcinfo) - hi = LoopIR.BinOp("-", w.hi, off, T.index, w.srcinfo) - return LoopIR.Interval(lo, hi, w.srcinfo) - else: - assert isinstance(w, LoopIR.Point) - pt = LoopIR.BinOp("-", w.pt, off, T.index, w.srcinfo) - return LoopIR.Point(pt, w.srcinfo) - - w_los = [w_e[0] if isinstance(w_e, tuple) else w_e for w_e in w_exprs] - - return [off_w(w_i, w_e) for w_i, w_e in zip(w_idx, w_los)] + def rewrite_win(w_idx): + assert len(w_idx) == len(w_exprs) - ir = block_cursor.get_root() - block = [s._node for s in block_cursor] - if use_accum_zero: - n_dims = len(buf_typ.shape()) - Check_BufferReduceOnly( - ir, - block, - buf_name, - n_dims, - ) + def off_w(w, off): + if isinstance(w, LoopIR.Interval): + lo = LoopIR.BinOp("-", w.lo, off, T.index, w.srcinfo) + hi = LoopIR.BinOp("-", w.hi, off, T.index, w.srcinfo) + return LoopIR.Interval(lo, hi, w.srcinfo) + else: + assert isinstance(w, LoopIR.Point) + pt = LoopIR.BinOp("-", w.pt, off, T.index, w.srcinfo) + return LoopIR.Point(pt, w.srcinfo) - n_dims = len(buf_typ.shape()) - basetyp = new_typ.basetype() if isinstance(new_typ, T.Tensor) else new_typ - srcinfo = 
block[0].srcinfo + w_los = [w_e[0] if isinstance(w_e, tuple) else w_e for w_e in w_exprs] - new_alloc = [LoopIR.Alloc(new_name, new_typ, mem, srcinfo)] - ir, fwd = block_cursor[0].before()._insert(new_alloc) + return [off_w(w_i, w_e) for w_i, w_e in zip(w_idx, w_los)] - def get_inner_stmt(loop_nest_c): - node = loop_nest_c._node - if not isinstance(node, LoopIR.For): - return loop_nest_c - return get_inner_stmt(loop_nest_c.body()[0]) + ir = block_cursor.get_root() + block = [s._node for s in block_cursor] + if use_accum_zero: + n_dims = len(buf_typ.shape()) + if checker == "static": + Check_BufferReduceOnly( + ir, + block, + buf_name, + n_dims, + ) - # Insert guards to ensure load/store stages don't access out of bounds - def insert_safety_guards(ir, fwd, ctxt_stmt_c, access, buf_typ): - def check_cond(cond): - ctxt_stmt = ctxt_stmt_c._node - true_node = LoopIR.Const(True, T.bool, ctxt_stmt.srcinfo) - try: - Check_ExprEqvInContext(ir, cond, [ctxt_stmt], true_node) - return True - except SchedulingError: - return False + n_dims = len(buf_typ.shape()) + basetyp = new_typ.basetype() if isinstance(new_typ, T.Tensor) else new_typ + srcinfo = block[0].srcinfo + + new_alloc = [LoopIR.Alloc(new_name, new_typ, mem, srcinfo)] + ir, fwd = block_cursor[0].before()._insert(new_alloc) + + def get_inner_stmt(loop_nest_c): + node = loop_nest_c._node + if not isinstance(node, LoopIR.For): + return loop_nest_c + return get_inner_stmt(loop_nest_c.body()[0]) + + if checker == "dynamic": + coverage_summary = fuzz_single_scope( + block_cursor, + StageMemArgs(buf_name, StagedWindowExpr(w_exprs), block_cursor), + ) + else: + coverage_summary = None + + # Insert guards to ensure load/store stages don't access out of bounds + def insert_safety_guards(ir, fwd, ctxt_stmt_c, access, buf_typ): + def check_cond(cond): + ctxt_stmt = ctxt_stmt_c._node + true_node = LoopIR.Const(True, T.bool, ctxt_stmt.srcinfo) + try: + Check_ExprEqvInContext(ir, cond, [ctxt_stmt], true_node) + return True + except SchedulingError: + return False + + # Get a list of lower/upper bound on the index accesses + const_0 = LoopIR.Const(0, T.int, access.srcinfo) + conds = [] + if coverage_summary is None: + for i in zip(access.idx, buf_typ.shape()): + lower_bound_cond = LoopIR.BinOp( + "<=", const_0, i[0], T.bool, access.srcinfo + ) + if not check_cond(lower_bound_cond): + conds.append(lower_bound_cond) + upper_bound_cond = LoopIR.BinOp( + "<", i[0], i[1], T.bool, access.srcinfo + ) + if not check_cond(upper_bound_cond): + conds.append(upper_bound_cond) + else: + for ( + access_idx, + buf_upper_bound, + guard_upper_bound, + guard_lower_bound, + ) in zip( + access.idx, + buf_typ.shape(), + coverage_summary.stage_mem_result.dim_needs_upper_bound_guard, + coverage_summary.stage_mem_result.dim_needs_lower_bound_guard, + ): + lower_bound_cond = LoopIR.BinOp( + "<=", const_0, access_idx, T.bool, access.srcinfo + ) + if guard_lower_bound: + conds.append(lower_bound_cond) + upper_bound_cond = LoopIR.BinOp( + "<", access_idx, buf_upper_bound, T.bool, access.srcinfo + ) + if guard_upper_bound: + conds.append(upper_bound_cond) + + if len(conds) == 0: + return ir, fwd + + # Construct the condition + cond = conds[0] + for c in conds[1:]: + cond = LoopIR.BinOp("and", cond, c, T.bool, cond.srcinfo) + + # Construct the If statement and wrap the context statement + def guard_wrapper(body): + return LoopIR.If(cond, body, [], srcinfo) + + # You want to forward `ctxt_stmt_c` instead of relying on passing + # the forwarded version. 
However, in all the current callees, the
+            # statement would have been just constructed and if you try to forward
+            # you get an error.
+            ir, fwd_wrap = ctxt_stmt_c.as_block()._wrap(guard_wrapper, "body")
+            fwd = _compose(fwd_wrap, fwd)

-        # Get a list of lower/upper bound on the index accesses
-        const_0 = LoopIR.Const(0, T.int, access.srcinfo)
-        conds = []
-        for i in zip(access.idx, buf_typ.shape()):
-            lower_bound_cond = LoopIR.BinOp("<=", const_0, i[0], T.bool, access.srcinfo)
-            if not check_cond(lower_bound_cond):
-                conds.append(lower_bound_cond)
-            upper_bound_cond = LoopIR.BinOp("<", i[0], i[1], T.bool, access.srcinfo)
-            if not check_cond(upper_bound_cond):
-                conds.append(upper_bound_cond)
-
-        if len(conds) == 0:
             return ir, fwd

-        # Construct the condition
-        cond = conds[0]
-        for c in conds[1:]:
-            cond = LoopIR.BinOp("and", cond, c, T.bool, cond.srcinfo)
+        def idx_contained_by_window(idx, block_cursor):
+            """
+            Returns True if idx always lies in staged window range.
+            Returns False if idx never lies in staged window range.
+            Otherwise, will raise a SchedulingError.
+            """
+            if coverage_summary is None:
+                p = idx.get_root()
+                return Check_Access_In_Window(p, idx, w_exprs, block_cursor)
+            else:
+                return (
+                    idx.get_path()
+                    in coverage_summary.stage_mem_result.overlapping_accesses
+                )

-        # Construct the If statement and wrap the context statement
-        def guard_wrapper(body):
-            return LoopIR.If(cond, body, [], srcinfo)
+        actualR = actualW = False
+        WShadow = False
+        # Conservatively, shadowing logic only works for single element staging windows.

-        # You want to forward `ctxt_stmt_c` instead of relying on passing
-        # the forwarded version. However, in all the current callees, the
-        # statement would have been just constructed and if you try to forward
-        # you get an error.
-        ir, fwd_wrap = ctxt_stmt_c.as_block()._wrap(guard_wrapper, "body")
-        fwd = _compose(fwd_wrap, fwd)
+        w_is_pt = all(not isinstance(w, tuple) for w in w_exprs)

-        return ir, fwd
+        def mk_read(c, block_cursor):
+            nonlocal actualR
+            rd = c._node

-    def idx_contained_by_window(idx, block_cursor):
-        """
-        Returns True if idx always lies in staged window range.
-        Returns False if idx never lies in staged window range.
-        Otherwise, will raise a SchedulingError.
-        """
-        p = idx.get_root()
-        return Check_Access_In_Window(p, idx, w_exprs, block_cursor)
+            if isinstance(rd, LoopIR.Read):
+                if idx_contained_by_window(c, block_cursor):
+                    _idx = rewrite_idx(rd.idx)
+                    actualR = True
+                    return {"name": new_name, "idx": _idx}
+            elif isinstance(rd, LoopIR.WindowExpr):
+                if any(
+                    isinstance(w, LoopIR.Interval) and not isinstance(w_e, tuple)
+                    for w, w_e in zip(rd.idx, w_exprs)
+                ):
+                    raise SchedulingError(
+                        f"Existing WindowExpr {rd} has a windowed dimension which is not windowed in the new staged window."

-    actualR = actualW = False
-    WShadow = False
-    # Conservatively, shadowing logic only works for single element staging windows.
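+        # Under dynamic checking, containment is decided from fuzzing evidence
+        # rather than the static Check_Access_In_Window analysis: an access is
+        # treated as staged only when its cursor path was recorded in
+        # coverage_summary.stage_mem_result.overlapping_accesses, i.e. the
+        # fuzzer actually observed it landing inside the staged window.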
- w_is_pt = all(not isinstance(w, tuple) for w in w_exprs) + def mk_write(c, block_cursor): + nonlocal actualR + nonlocal actualW + nonlocal WShadow + s = c._node + if isinstance(s, (LoopIR.Assign, LoopIR.Reduce)): + if idx_contained_by_window(c, block_cursor): + actualW = True + if isinstance(s, LoopIR.Reduce): + actualR = True + if not actualR and w_is_pt: + WShadow = True + return {"name": new_name, "idx": rewrite_idx(s.idx)} + + for c in block_cursor: + ir, fwd = _replace_reads( + ir, fwd, c, buf_name, partial(mk_read, block_cursor=fwd(block_cursor)) + ) - def mk_read(c, block_cursor): - nonlocal actualR - rd = c._node + ir, fwd = _replace_writes( + ir, fwd, c, buf_name, partial(mk_write, block_cursor=fwd(block_cursor)) + ) - if isinstance(rd, LoopIR.Read): - if idx_contained_by_window(c, block_cursor): - _idx = rewrite_idx(rd.idx) - actualR = True - return {"name": new_name, "idx": _idx} - elif isinstance(rd, LoopIR.WindowExpr): - if any( - isinstance(w, LoopIR.Interval) and not isinstance(w_e, tuple) - for w, w_e in zip(rd.idx, w_exprs) - ): - raise SchedulingError( - f"Existing WindowExpr {rd} has a widnowed dimension which is not windowed in the new staged window." + if actualR and not WShadow: + load_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)] + load_widx = [LoopIR.Read(s, [], T.index, srcinfo) for s in load_iter] + if use_accum_zero: + load_rhs = LoopIR.Const(0.0, basetyp, srcinfo) + else: + cp_load_widx = load_widx.copy() + load_ridx = [] + for w in w_exprs: + if isinstance(w, tuple): + load_ridx.append( + LoopIR.BinOp( + "+", cp_load_widx.pop(0), w[0], T.index, srcinfo + ) + ) + else: + load_ridx.append(w) + load_rhs = LoopIR.Read(buf_name, load_ridx, basetyp, srcinfo) + + load_nest = [LoopIR.Assign(new_name, basetyp, load_widx, load_rhs, srcinfo)] + + for i, n in reversed(list(zip(load_iter, shape))): + loop = LoopIR.For( + i, + LoopIR.Const(0, T.index, srcinfo), + n, + load_nest, + LoopIR.Seq(), + srcinfo, ) + load_nest = [loop] - if idx_contained_by_window(c, block_cursor): - _idx = rewrite_win(rd.idx) - _typ = T.Window(new_typ, rd.type.as_tensor, new_name, _idx) - actualR = True - return {"name": new_name, "idx": _idx, "type": _typ} - - def mk_write(c, block_cursor): - nonlocal actualR - nonlocal actualW - nonlocal WShadow - s = c._node - if isinstance(s, (LoopIR.Assign, LoopIR.Reduce)): - if idx_contained_by_window(c, block_cursor): - actualW = True - if isinstance(s, LoopIR.Reduce): - actualR = True - if not actualR and w_is_pt: - WShadow = True - return {"name": new_name, "idx": rewrite_idx(s.idx)} - - for c in block_cursor: - ir, fwd = _replace_reads( - ir, fwd, c, buf_name, partial(mk_read, block_cursor=fwd(block_cursor)) - ) + ir, fwd_ins = fwd(block_cursor[0]).before()._insert(load_nest) + fwd = _compose(fwd_ins, fwd) - ir, fwd = _replace_writes( - ir, fwd, c, buf_name, partial(mk_write, block_cursor=fwd(block_cursor)) - ) + if not use_accum_zero: + load_nest_c = fwd(block_cursor[0]).prev() + ir, fwd = insert_safety_guards( + ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ + ) - if actualR and not WShadow: - load_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)] - load_widx = [LoopIR.Read(s, [], T.index, srcinfo) for s in load_iter] - if use_accum_zero: - load_rhs = LoopIR.Const(0.0, basetyp, srcinfo) - else: - cp_load_widx = load_widx.copy() - load_ridx = [] + if actualW: + store_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)] + store_ridx = [LoopIR.Read(s, [], T.index, srcinfo) for s in store_iter] + cp_store_ridx = store_ridx.copy() + store_widx 
= [] for w in w_exprs: if isinstance(w, tuple): - load_ridx.append( - LoopIR.BinOp("+", cp_load_widx.pop(0), w[0], T.index, srcinfo) + store_widx.append( + LoopIR.BinOp("+", cp_store_ridx.pop(0), w[0], T.index, srcinfo) ) else: - load_ridx.append(w) - load_rhs = LoopIR.Read(buf_name, load_ridx, basetyp, srcinfo) - - load_nest = [LoopIR.Assign(new_name, basetyp, load_widx, load_rhs, srcinfo)] - - for i, n in reversed(list(zip(load_iter, shape))): - loop = LoopIR.For( - i, - LoopIR.Const(0, T.index, srcinfo), - n, - load_nest, - LoopIR.Seq(), - srcinfo, - ) - load_nest = [loop] + store_widx.append(w) + + store_rhs = LoopIR.Read(new_name, store_ridx, basetyp, srcinfo) + store_stmt = LoopIR.Reduce if use_accum_zero else LoopIR.Assign + store_nest = [store_stmt(buf_name, basetyp, store_widx, store_rhs, srcinfo)] + + for i, n in reversed(list(zip(store_iter, shape))): + loop = LoopIR.For( + i, + LoopIR.Const(0, T.index, srcinfo), + n, + store_nest, + LoopIR.Seq(), + srcinfo, + ) + store_nest = [loop] - ir, fwd_ins = fwd(block_cursor[0]).before()._insert(load_nest) - fwd = _compose(fwd_ins, fwd) + ir, fwd_ins = fwd(block_cursor[-1]).after()._insert(store_nest) + fwd = _compose(fwd_ins, fwd) - if not use_accum_zero: - load_nest_c = fwd(block_cursor[0]).prev() + store_nest_c = fwd(block_cursor[-1]).next() + store_stmt_c = get_inner_stmt(store_nest_c) ir, fwd = insert_safety_guards( - ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ + ir, fwd, store_stmt_c, store_stmt_c._node, buf_typ ) - if actualW: - store_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)] - store_ridx = [LoopIR.Read(s, [], T.index, srcinfo) for s in store_iter] - cp_store_ridx = store_ridx.copy() - store_widx = [] - for w in w_exprs: - if isinstance(w, tuple): - store_widx.append( - LoopIR.BinOp("+", cp_store_ridx.pop(0), w[0], T.index, srcinfo) - ) - else: - store_widx.append(w) - - store_rhs = LoopIR.Read(new_name, store_ridx, basetyp, srcinfo) - store_stmt = LoopIR.Reduce if use_accum_zero else LoopIR.Assign - store_nest = [store_stmt(buf_name, basetyp, store_widx, store_rhs, srcinfo)] - - for i, n in reversed(list(zip(store_iter, shape))): - loop = LoopIR.For( - i, - LoopIR.Const(0, T.index, srcinfo), - n, - store_nest, - LoopIR.Seq(), - srcinfo, + # new alloc, load_nest + new_body + store_nest + new_block_c = fwd(block_cursor[0]).as_block().expand(0, len(block_cursor) - 1) + if actualR and not WShadow: + new_block_c = new_block_c.expand(1, 0) + if actualW: + new_block_c = new_block_c.expand(0, 1) + if not actualR and not actualW: + raise SchedulingError( + f"Cannot stage '{buf_name}' with the given window shape. Wrong window shape, or '{buf_name}' not accessed in the given scope?" ) - store_nest = [loop] - - ir, fwd_ins = fwd(block_cursor[-1]).after()._insert(store_nest) - fwd = _compose(fwd_ins, fwd) - store_nest_c = fwd(block_cursor[-1]).next() - store_stmt_c = get_inner_stmt(store_nest_c) - ir, fwd = insert_safety_guards( - ir, fwd, store_stmt_c, store_stmt_c._node, buf_typ - ) - - # new alloc, load_nest + new_body + store_nest - new_block_c = fwd(block_cursor[0]).as_block().expand(0, len(block_cursor) - 1) - if actualR and not WShadow: - new_block_c = new_block_c.expand(1, 0) - if actualW: - new_block_c = new_block_c.expand(0, 1) - if not actualR and not actualW: - raise SchedulingError( - f"Cannot stage '{buf_name}' with the given window shape. Wrong window shape, or '{buf_name}' not accessed in the given scope?" 
- ) + if checker == "static": + Check_Bounds(ir, new_alloc[0], [c._node for c in new_block_c]) + else: + scope = block_cursor.parent() + if scope.depth() == 0: + scope = scope._child_block("body") + else: + scope = scope.as_block() + fuzz(scope, fwd(scope)) - Check_Bounds(ir, new_alloc[0], [c._node for c in new_block_c]) + return ir, fwd - return ir, fwd + return do_check(check, check_mode) def DoUnrollBuffer(alloc_cursor, dim): diff --git a/src/exo/rewrite/chexo/LoopIR_transpiler.py b/src/exo/rewrite/chexo/LoopIR_transpiler.py new file mode 100644 index 000000000..a70a6704c --- /dev/null +++ b/src/exo/rewrite/chexo/LoopIR_transpiler.py @@ -0,0 +1,1489 @@ +from functools import reduce +from itertools import chain +from string import Template +from typing import Any, Callable, Generator, Iterable, Optional, Union + +from ...core.configs import Config + +from ...core.prelude import Sym +from ...core.LoopIR import LoopIR, T +from .coverage import ( + CoverageSkeleton, + CoverageSkeletonNode, + CoverageSkeletonBranch, + FailureCondition, + IndexedFiller, + MemoryAccess, + MemoryAccessPair, + ParallelAccess, + ParallelAccessPair, + StagingBoundCheck, + SymbolicPoint, + SymbolicSlice, + StagingOverlap, + SymbolicWindowIndex, +) +from ...core.internal_cursors import Block, Cursor, Node, NodePath +from .constraint_solver import ( + TRUE_CONSTRAINT, + Constraint, + ConstraintMaker, + DisjointConstraint, + Expression, +) +from dataclasses import dataclass, field + +import numpy as np + + +@dataclass +class BaseType: + loopir_type: Any + dtype: type + javascript_array_type: str + + +base_types = ( + BaseType(T.F16, np.float16, "Float16Array"), + BaseType(T.F32, np.float32, "Float32Array"), + BaseType(T.F64, np.float64, "Float64Array"), + BaseType(T.INT8, np.int8, "Int8Array"), + BaseType(T.UINT8, np.uint8, "Uint8Array"), + BaseType(T.UINT16, np.uint16, "Uint16Array"), + BaseType(T.INT32, np.int32, "Int32Array"), + BaseType(T.Num, np.float64, "Float64Array"), +) + + +def lookup_loopir_type(loopir_type: Any) -> BaseType: + return next( + base_type + for base_type in base_types + if isinstance(loopir_type, base_type.loopir_type) + ) + + +def lookup_dtype(dtype: type) -> BaseType: + return next(base_type for base_type in base_types if base_type.dtype == dtype) + + +@dataclass +class Constant: + name: str + + +@dataclass +class Reference: + name: str + is_config: bool + + +@dataclass +class Point: + index: str + + +@dataclass +class Slice: + lower_bound: str + upper_bound: str + + +WindowIndex = Union[Point, Slice] + + +@dataclass +class Dimension: + size: str + stride: str + window_idx: WindowIndex + + +@dataclass +class Tensor: + name: Sym + dims: tuple[Dimension, ...] + + +ExoValue = Union[Constant, Reference, Tensor] + + +CONTEXT_OBJECT_NAME = "ctxt" +INITIAL_DYNAMIC_SIZE = 16 + + +@dataclass +class SymbolicTensor: + name: Sym + dims: tuple[SymbolicWindowIndex, ...] 
+ resize_placeholder: Optional[int] + + +class AliasingTracker: + def __init__(self, parent_state: "CoverageState"): + self.writes: dict[Sym, list[MemoryAccess]] = {} + self.reads: dict[Sym, list[MemoryAccess]] = {} + self.parent_state = parent_state + + def access_tensor( + self, + js_tensor: Tensor, + js_idxs: tuple[str, ...], + symbolic_idxs: tuple[Expression], + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + js_access = "+".join( + f"Math.imul({js_idx},{js_dim.stride})" + for js_idx, js_dim in zip(js_idxs, js_tensor.dims) + ) + access_sym = Sym("access") + mark_stmt = f"{repr(access_sym)}[{js_access}]=1;" + base_size = f"Math.imul({js_tensor.dims[0].size},{js_tensor.dims[0].stride})" + resize_placeholder = self.parent_state.symbolic_tensors[ + js_tensor.name + ].resize_placeholder + if resize_placeholder is None: + decl_stmt = f"let {repr(access_sym)}=new Uint8Array({base_size});" + fillers = ( + IndexedFiller(access_placeholder, mark_stmt), + IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), + ) + else: + temp_sym = Sym("temp") + decl_stmt = ( + f"let {repr(access_sym)}=new Uint8Array({INITIAL_DYNAMIC_SIZE});" + ) + resize_stmt = f"if({base_size}>{repr(access_sym)}.length){{let {repr(temp_sym)}=new Uint8Array(Math.max(2*{repr(access_sym)}.length,{base_size}));for(let i=0;i<{repr(access_sym)}.length;i++){repr(temp_sym)}[i]={repr(access_sym)}[i];{repr(access_sym)}={repr(temp_sym)}}};" + fillers = ( + IndexedFiller(access_placeholder, mark_stmt), + IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), + IndexedFiller(resize_placeholder, resize_stmt), + ) + + dest = self.writes if is_write else self.reads + if js_tensor.name not in dest: + dest[js_tensor.name] = [] + dest[js_tensor.name].append( + MemoryAccess(access_sym, cov_node, symbolic_idxs, fillers) + ) + + def access_scalar( + self, + name: Sym, + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + access_sym = Sym("access") + mark_stmt = f"{repr(access_sym)}=true;" + decl_stmt = f"let {repr(access_sym)}=false;" + + dest = self.writes if is_write else self.reads + if name not in dest: + dest[name] = [] + dest[name].append( + MemoryAccess( + access_sym, + cov_node, + (), + ( + IndexedFiller(access_placeholder, mark_stmt), + IndexedFiller(self.parent_state.cov_placeholder, decl_stmt), + ), + ) + ) + + def make_aliasing_accesses(self) -> tuple[MemoryAccessPair, ...]: + aliasable_accesses: list[MemoryAccessPair] = [] + for sym, write_indices in self.writes.items(): + read_indices = self.reads[sym] if sym in self.reads else [] + for i, index1 in enumerate(write_indices): + for index2 in chain(write_indices[i + 1 :], read_indices): + aliasable_accesses.append(MemoryAccessPair(index1, index2)) + return tuple(aliasable_accesses) + + +class FailureTracker: + def __init__(self, scope: Block, parent_state: "CoverageState"): + self.scope = scope + self.in_scope = False + self.call_depth = 0 + self.failure_conditions: list[FailureCondition] = [] + self.parent_state = parent_state + + def enter_stmt(self, stmt_cursor: Node): + if stmt_cursor in self.scope: + self.in_scope = True + + def exit_stmt(self, stmt_cursor): + if stmt_cursor in self.scope: + self.in_scope = False + + def enter_proc_body(self): + self.call_depth += 1 + + def exit_proc_body(self): + self.call_depth -= 1 + + def add_assertion( + self, asserted_cond: DisjointConstraint, js_cond: str, placeholder: int + ): + if self.in_scope and 
self.call_depth == 1:
+            fail_sym = Sym("fail")
+            self.failure_conditions.append(
+                FailureCondition(
+                    fail_sym,
+                    asserted_cond.invert(),
+                    self.parent_state.current_node,
+                    (
+                        IndexedFiller(
+                            self.parent_state.cov_placeholder,
+                            f"let {repr(fail_sym)}=false;",
+                        ),
+                        IndexedFiller(
+                            placeholder, f"if(!({js_cond})){repr(fail_sym)}=true;"
+                        ),
+                    ),
+                )
+            )
+
+    def make_failures(self) -> tuple[FailureCondition, ...]:
+        return tuple(self.failure_conditions)
+
+
+@dataclass
+class SymbolicWindow:
+    name: Sym
+    index: tuple[SymbolicWindowIndex, ...]
+
+
+@dataclass
+class StagedWindowExpr:
+    indices: tuple[Union[tuple[LoopIR.expr, LoopIR.expr], LoopIR.expr], ...]
+
+
+@dataclass
+class StageMemArgs:
+    buffer_sym: Sym
+    staged_window_expr: StagedWindowExpr
+    scope: Block
+
+
+class StageMemTracker:
+    def __init__(self, args: StageMemArgs, parent_state: "CoverageState"):
+        self.scope: Block = args.scope
+        self.buffer_sym = args.buffer_sym
+        self.staged_window_expr = args.staged_window_expr
+        self.staged_window: Optional[tuple[SymbolicTensor, Tensor]] = None
+        self.enabled: bool = False
+        self.overlaps: list[StagingOverlap] = []
+        self.bound_checks: list[StagingBoundCheck] = []
+        self.parent_state: "CoverageState" = parent_state
+
+    def enter_stmt(self, stmt_node: Node):
+        if stmt_node in self.scope:
+            self.enabled = True
+
+            if self.staged_window is None:
+                staged_win_sym = Sym("win")
+                js_parent = self.parent_state.parent_transpiler._lookup_sym(
+                    self.buffer_sym
+                )
+                js_staged = self.parent_state.parent_transpiler._transpile_window(
+                    staged_win_sym,
+                    self.buffer_sym,
+                    Cursor.create(self.staged_window_expr)._child_block("indices"),
+                )
+                assert isinstance(js_parent, Tensor)
+                symbolic_parent = self.parent_state.symbolic_tensors[self.buffer_sym]
+                symbolic_staged = self.parent_state.symbolic_tensors[staged_win_sym]
+                stage_placeholder = (
+                    self.parent_state.parent_transpiler._make_placeholder()
+                )
+                for (
+                    symbolic_parent_dim,
+                    symbolic_staged_dim,
+                    js_parent_dim,
+                    js_staged_dim,
+                ) in zip(
+                    symbolic_parent.dims,
+                    symbolic_staged.dims,
+                    (dim.window_idx for dim in js_parent.dims),
+                    (dim.window_idx for dim in js_staged.dims),
+                ):
+                    if isinstance(symbolic_parent_dim, SymbolicSlice):
+                        assert isinstance(js_parent_dim, Slice)
+                        upper_bound_violation_sym = Sym("ub")
+                        lower_bound_violation_sym = Sym("lb")
+                        js_upper_bound_violation_cond = (
+                            f"({js_staged_dim.index}>={js_parent_dim.upper_bound})"
+                            if isinstance(js_staged_dim, Point)
+                            else f"(({js_staged_dim.upper_bound})>({js_parent_dim.upper_bound}))"
+                        )
+                        js_lower_bound_violation_cond = (
+                            f"({js_staged_dim.index}<{js_parent_dim.lower_bound})"
+                            if isinstance(js_staged_dim, Point)
+                            else f"(({js_staged_dim.lower_bound})<({js_parent_dim.lower_bound}))"
+                        )
+                        self.bound_checks.append(
+                            StagingBoundCheck(
+                                upper_bound_violation_sym,
+                                lower_bound_violation_sym,
+                                symbolic_staged_dim,
+                                symbolic_parent_dim,
+                                self.parent_state.current_node,
+                                (
+                                    IndexedFiller(
+                                        self.parent_state.cov_placeholder,
+                                        "".join(
+                                            (
+                                                f"let {repr(upper_bound_violation_sym)}=false;",
+                                                f"let {repr(lower_bound_violation_sym)}=false;",
+                                            )
+                                        ),
+                                    ),
+                                    IndexedFiller(
+                                        stage_placeholder,
+                                        "".join(
+                                            (
+                                                f"if({js_upper_bound_violation_cond}){{{repr(upper_bound_violation_sym)}=true;}}",
+                                                f"if({js_lower_bound_violation_cond}){{{repr(lower_bound_violation_sym)}=true;}}",
+                                            )
+                                        ),
+                                    ),
+                                ),
+                            )
+                        )
+                self.staged_window = (
+                    symbolic_staged,
+                    js_staged,
+                )
+
+    def exit_stmt(self, stmt_cursor: Node):
+        if stmt_cursor in self.scope:
+            self.enabled = False
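+
+    # Shape of the generated JS for one staged Point index `i` over a parent
+    # slice [lo, hi) (sym names are illustrative): the coverage placeholder
+    # receives `let ub=false;let lb=false;` and the stage placeholder receives
+    # `if((i>=hi)){ub=true;}if((i<lo)){lb=true;}`, so each flag records whether
+    # the staged window ever escaped the parent window on that dimension.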
+
+    def access_tensor(
+        self,
+        js_tensor: Tensor,
+        js_idxs: tuple[str, ...],
+        symbolic_idxs: tuple[Expression, ...],
+        access_placeholder: int,
+        access_cursor: Node,
+        cov_node: CoverageSkeletonNode,
+        _is_write: bool,
+    ):
+        if (
+            self.staged_window is not None
+            and self.enabled
+            and self.staged_window[1].name == js_tensor.name
+        ):
+            symbolic_staged_window, js_staged_window = self.staged_window
+            overlap_sym = Sym("overlap")
+            disjoint_sym = Sym("disjoint")
+            access_window = tuple(SymbolicPoint(idx) for idx in symbolic_idxs)
+            js_overlap_cond = "&&".join(
+                f"(({js_staged_slice.lower_bound})<=({js_idx}))&&(({js_staged_slice.upper_bound})>({js_idx}))"
+                for js_idx, js_staged_slice in zip(
+                    js_idxs,
+                    (
+                        dim.window_idx
+                        for dim in js_staged_window.dims
+                        if isinstance(dim.window_idx, Slice)
+                    ),
+                )
+            )
+            self.overlaps.append(
+                StagingOverlap(
+                    overlap_sym,
+                    disjoint_sym,
+                    symbolic_staged_window.dims,
+                    access_window,
+                    cov_node,
+                    access_cursor.get_path(),
+                    (
+                        IndexedFiller(
+                            access_placeholder,
+                            f"if({js_overlap_cond}){{{repr(overlap_sym)}=true}}else{{{repr(disjoint_sym)}=true}}",
+                        ),
+                        IndexedFiller(
+                            self.parent_state.cov_placeholder,
+                            f"let {repr(overlap_sym)}=false;let {repr(disjoint_sym)}=false;",
+                        ),
+                    ),
+                )
+            )
+
+    def make_staging_overlaps(self) -> tuple[StagingOverlap, ...]:
+        return tuple(self.overlaps)
+
+    def make_staging_bound_checks(self) -> tuple[StagingBoundCheck, ...]:
+        return tuple(self.bound_checks)
+
+
+@dataclass
+class ParallelScope:
+    iter_sym: Sym
+    loop_entrance_placeholder: int
+    writes: dict[Sym, list[ParallelAccess]] = field(default_factory=lambda: {})
+    reads: dict[Sym, list[ParallelAccess]] = field(default_factory=lambda: {})
+    access_set_syms: dict[Sym, Sym] = field(default_factory=lambda: {})
+
+
+class ParallelAccessTracker:
+    def __init__(self, parent_state: "CoverageState"):
+        self.parallel_scopes: list[ParallelScope] = []
+        self.coverage_sym: Sym = Sym("par")
+        self.pairs: list[ParallelAccessPair] = []
+        self.parent_state: "CoverageState" = parent_state
+
+    def enter_loop(self, loop: LoopIR.For, loop_entrance_placeholder: int):
+        if isinstance(loop.loop_mode, LoopIR.Par):
+            self.parallel_scopes.append(
+                ParallelScope(loop.iter, loop_entrance_placeholder)
+            )
+
+    def exit_loop_body(self, loop: LoopIR.For):
+        if isinstance(loop.loop_mode, LoopIR.Par):
+            scope = self.parallel_scopes.pop()
+            loop_tail_placeholder = (
+                self.parent_state.parent_transpiler._make_placeholder()
+            )
+            for sym, sym_writes in scope.writes.items():
+                for sym_write_idx, sym_write in enumerate(sym_writes):
+                    for other_access in chain(
+                        sym_writes[sym_write_idx:],
+                        (scope.reads[sym] if sym in scope.reads else []),
+                    ):
+                        self.pairs.append(
+                            ParallelAccessPair(
+                                self.coverage_sym,
+                                scope.iter_sym,
+                                sym_write,
+                                other_access,
+                                (
+                                    IndexedFiller(
+                                        loop_tail_placeholder,
+                                        f"{repr(scope.access_set_syms[sym])}_pw={repr(scope.access_set_syms[sym])}_pw.union({repr(scope.access_set_syms[sym])}_cw);",
+                                    ),
+                                    IndexedFiller(
+                                        loop_tail_placeholder,
+                                        f"{repr(scope.access_set_syms[sym])}_pr={repr(scope.access_set_syms[sym])}_pr.union({repr(scope.access_set_syms[sym])}_cr);",
+                                    ),
+                                ),
+                            )
+                        )
+
+    def access_tensor(
+        self,
+        js_tensor: Tensor,
+        js_idxs: tuple[str, ...],
+        symbolic_idxs: tuple[Expression, ...],
+        access_placeholder: int,
+        _access_cursor: Node,
+        cov_node: CoverageSkeletonNode,
+        is_write: bool,
+    ):
+        js_access = "+".join(
+            f"Math.imul({js_idx},{js_dim.stride})"
+            for js_idx, js_dim in zip(js_idxs, js_tensor.dims)
+        )
+        for parallel_scope in 
self.parallel_scopes: + dest = parallel_scope.writes if is_write else parallel_scope.reads + if js_tensor.name not in dest: + dest[js_tensor.name] = [] + if js_tensor.name not in parallel_scope.access_set_syms: + parallel_scope.access_set_syms[js_tensor.name] = Sym("access_set") + access_set_sym = parallel_scope.access_set_syms[js_tensor.name] + dest[js_tensor.name].append( + ParallelAccess( + cov_node, + symbolic_idxs, + ( + IndexedFiller( + parallel_scope.loop_entrance_placeholder, + "".join( + ( + f"let {repr(access_set_sym)}_pr=new Set();", + f"let {repr(access_set_sym)}_pw=new Set();", + f"let {repr(access_set_sym)}_cr=new Set();", + f"let {repr(access_set_sym)}_cw=new Set();", + ) + ), + ), + IndexedFiller( + self.parent_state.cov_placeholder, + f"let {repr(self.coverage_sym)}=false;", + ), + IndexedFiller( + access_placeholder, + "".join( + ( + f"{repr(access_set_sym)}{'_cw' if is_write else '_cr'}.add({js_access});", + f"if({repr(access_set_sym)}_pw.has({js_access})){{{repr(self.coverage_sym)}=true;return [1,{CONTEXT_OBJECT_NAME},{{}}];}}", + *( + ( + f"if({repr(access_set_sym)}_pr.has({js_access})){{{repr(self.coverage_sym)}=true;return [1,{CONTEXT_OBJECT_NAME},{{}}];}}", + ) + if is_write + else () + ), + ), + ), + ), + ), + ) + ) + + def access_scalar( + self, + name: Sym, + access_placeholder: int, + _access_cursor: Node, + cov_node: CoverageSkeletonNode, + is_write: bool, + ): + self.access_tensor( + Tensor(name, (Dimension("1", "0", Slice("0", "1")),)), + ("0",), + tuple(), + access_placeholder, + _access_cursor, + cov_node, + is_write, + ) + + def make_parallel_access_pairs(self) -> tuple[ParallelAccessPair, ...]: + return tuple(self.pairs) + + +@dataclass +class CoverageArgs: + cm: ConstraintMaker + var_renaming: dict[Sym, Sym] + failure_scope: Optional[Block] = None + stage_mem_args: Optional[StageMemArgs] = None + + +class CoverageState: + def __init__(self, args: CoverageArgs, parent_transpiler: "Transpiler"): + self.cm: ConstraintMaker = args.cm + self.var_renaming: dict[Sym, Sym] = args.var_renaming + self.parent_transpiler: Transpiler = parent_transpiler + self.cov_placeholder: int = parent_transpiler._make_placeholder() + self.root: CoverageSkeletonNode = CoverageSkeletonNode(None, None, ()) + self.current_node: CoverageSkeletonNode = self.root + self.symbolic_tensors: dict[Sym, SymbolicTensor] = {} + self.scalar_symbols: dict[Sym, Sym] = {} + self.ctxt_symbols: dict[tuple[Config, str], Sym] = {} + self.aliasing_tracker = AliasingTracker(self) + self.failure_tracker: Optional[FailureTracker] = ( + None + if args.failure_scope is None + else FailureTracker(args.failure_scope, self) + ) + self.stage_mem_tracker: Optional[StageMemTracker] = ( + None + if args.stage_mem_args is None + else StageMemTracker(args.stage_mem_args, self) + ) + self.parallel_access_tracker = ParallelAccessTracker(self) + self.free_vars: list[Sym] = [] + + def enter_loop( + self, + stmt: LoopIR.For, + lo_js: str, + hi_js: str, + transpile_loop_body: Callable[[], None], + ): + body_sym = Sym("body") + skip_sym = Sym("skip") + loop_entrance_placeholder = self.parent_transpiler._make_placeholder() + body_constraint = ( + self.cm.make_constraint_from_inequality( + stmt.lo, stmt.iter, "<=", self.var_renaming + ) + .lift_to_disjoint_constraint() + .intersect( + self.cm.make_constraint_from_inequality( + stmt.iter, stmt.hi, "<", self.var_renaming + ).lift_to_disjoint_constraint() + ) + ) + skip_constraint = self.cm.make_constraint_from_inequality( + stmt.lo, stmt.hi, ">=", self.var_renaming + 
).lift_to_disjoint_constraint() + parent_node = self.current_node + body_child = CoverageSkeletonNode( + body_sym, + (parent_node, body_constraint), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(body_sym)}=false;", + ), + IndexedFiller( + loop_entrance_placeholder, + f"{repr(body_sym)}||=({lo_js}<{hi_js});", + ), + ), + ) + skip_child = CoverageSkeletonNode( + skip_sym, + (parent_node, skip_constraint), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(skip_sym)}=false;", + ), + IndexedFiller( + loop_entrance_placeholder, + f"{repr(skip_sym)}||=({lo_js}>={hi_js});", + ), + ), + ) + self.current_node.branches.append( + CoverageSkeletonBranch(body_child, skip_child) + ) + self.parallel_access_tracker.enter_loop(stmt, loop_entrance_placeholder) + self.free_vars.append(stmt.iter) + self.current_node = body_child + transpile_loop_body() + self.current_node = parent_node + self.parallel_access_tracker.exit_loop_body(stmt) + + def enter_if( + self, + stmt: LoopIR.If, + transpile_if_body: Callable[[], None], + transpile_else_body: Callable[[], None], + ): + parent_node = self.current_node + true_sym = Sym("true_case") + false_sym = Sym("false_case") + true_placeholder = self.parent_transpiler._make_placeholder() + cond_constraint = self.cm.make_constraint(stmt.cond, self.var_renaming) + true_node = CoverageSkeletonNode( + true_sym, + (parent_node, cond_constraint), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(true_sym)}=false;", + ), + IndexedFiller(true_placeholder, f"{repr(true_sym)}=true;"), + ), + ) + self.current_node = true_node + transpile_if_body() + false_placeholder = self.parent_transpiler._make_placeholder() + false_node = CoverageSkeletonNode( + false_sym, + (parent_node, cond_constraint.invert()), + ( + IndexedFiller( + self.cov_placeholder, + f"let {repr(false_sym)}=false;", + ), + IndexedFiller(false_placeholder, f"{repr(false_sym)}=true;"), + ), + ) + self.current_node = false_node + transpile_else_body() + self.current_node = parent_node + new_branch = CoverageSkeletonBranch(true_node, false_node) + self.current_node.branches.append(new_branch) + + def enter_stmt(self, stmt_cursor: Node): + if self.failure_tracker is not None: + self.failure_tracker.enter_stmt(stmt_cursor) + if self.stage_mem_tracker is not None: + self.stage_mem_tracker.enter_stmt(stmt_cursor) + + def exit_stmt(self, stmt_cursor: Node): + if self.failure_tracker is not None: + self.failure_tracker.exit_stmt(stmt_cursor) + if self.stage_mem_tracker is not None: + self.stage_mem_tracker.exit_stmt(stmt_cursor) + + def enter_proc_body(self): + if self.failure_tracker is not None: + self.failure_tracker.enter_proc_body() + + def exit_proc_body(self): + if self.failure_tracker is not None: + self.failure_tracker.exit_proc_body() + + def assert_shape_matches( + self, tensor_sym: Sym, shape: list[LoopIR.expr], shape_matches_js: str + ): + match_cond = TRUE_CONSTRAINT + for tensor_dim, shape_dim in zip( + ( + dim + for dim in self.symbolic_tensors[tensor_sym].dims + if isinstance(dim, SymbolicSlice) + ), + shape, + ): + match_cond = match_cond.intersect( + Constraint( + self.cm.make_expression(shape_dim, self.var_renaming) + .negate() + .add(tensor_dim.upper_bound) + .add(tensor_dim.lower_bound.negate()), + False, + ).lift_to_disjoint_constraint(), + ) + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + match_cond, shape_matches_js, self.parent_transpiler._make_placeholder() + ) + + def assert_predicate(self, pred: LoopIR.expr, js_pred: str): + if 
self.failure_tracker is not None: + self.failure_tracker.add_assertion( + self.cm.make_constraint(pred, self.var_renaming), + js_pred, + self.parent_transpiler._make_placeholder(), + ) + + def make_tensor(self, sym: Sym, dims: list[LoopIR.expr], nonnegative_dims_js: str): + symbolic_dims = tuple( + self.cm.make_expression(dim, self.var_renaming) for dim in dims + ) + nonnegative_constraint = TRUE_CONSTRAINT + for symbolic_dim in symbolic_dims: + nonnegative_constraint = nonnegative_constraint.intersect( + Constraint(symbolic_dim, True).lift_to_disjoint_constraint() + ) + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + nonnegative_constraint, + nonnegative_dims_js, + self.parent_transpiler._make_placeholder(), + ) + self.symbolic_tensors[sym] = SymbolicTensor( + sym, + tuple( + SymbolicSlice(Expression.from_constant(0), symbolic_dim) + for symbolic_dim in symbolic_dims + ), + self.parent_transpiler._make_placeholder(), + ) + + def make_scalar(self, sym: Sym): + self.scalar_symbols[sym] = sym + + def assign_tensor(self, arg_name: Sym, original_name: Sym): + self.symbolic_tensors[arg_name] = self.symbolic_tensors[original_name] + + def assign_scalar(self, arg_name: Sym, original_name: Sym): + self.scalar_symbols[arg_name] = self.scalar_symbols[original_name] + + def assign_scalar_from_context(self, scalar_sym: Sym, config: Config, field: str): + config_key = (config, field) + if config_key not in self.ctxt_symbols: + self.ctxt_symbols[config_key] = Sym("ctxt") + self.scalar_symbols[scalar_sym] = self.ctxt_symbols[config_key] + + def assign_window( + self, sym: Sym, source_buf: Sym, access_cursor: Block, in_bounds_js: str + ): + base_tensor = self.symbolic_tensors[source_buf] + in_bounds_constraint = TRUE_CONSTRAINT + window_dims = [] + window_idx_iter = iter(access_cursor) + for dim in base_tensor.dims: + if isinstance(dim, SymbolicPoint): + window_dims.append(dim) + else: + idx = next(window_idx_iter)._node + if isinstance(idx, LoopIR.Interval): + new_dim = SymbolicSlice( + self.cm.make_expression(idx.lo, self.var_renaming).add( + dim.lower_bound + ), + self.cm.make_expression(idx.hi, self.var_renaming).add( + dim.lower_bound + ), + ) + in_bounds_constraint = in_bounds_constraint.intersect( + Constraint( + new_dim.lower_bound.add(dim.lower_bound.negate()), True + ).lift_to_disjoint_constraint() + ).intersect( + Constraint( + dim.upper_bound.add(new_dim.upper_bound.negate()), True + ).lift_to_disjoint_constraint() + ) + window_dims.append(new_dim) + else: + new_dim = SymbolicPoint( + self.cm.make_expression(idx.pt, self.var_renaming).add( + dim.lower_bound + ) + ) + in_bounds_constraint = in_bounds_constraint.intersect( + Constraint( + new_dim.index.add(dim.lower_bound.negate()), True + ).lift_to_disjoint_constraint() + ).intersect( + Constraint( + dim.upper_bound.add(new_dim.index.negate()).add( + Expression.from_constant(-1) + ), + True, + ).lift_to_disjoint_constraint() + ) + window_dims.append(new_dim) + + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + in_bounds_constraint, + in_bounds_js, + self.parent_transpiler._make_placeholder(), + ) + self.symbolic_tensors[sym] = SymbolicTensor( + base_tensor.name, + tuple(window_dims), + base_tensor.resize_placeholder, + ) + + def access_tensor( + self, + access_cursor: Node, + js_idxs: tuple[str, ...], + is_write: bool, + in_bounds_js: str, + ): + js_tensor = self.parent_transpiler._lookup_sym(access_cursor._node.name) + assert isinstance(js_tensor, Tensor) + symbolic_tensor = 
self.symbolic_tensors[access_cursor._node.name] + idx_expr_iter = iter(access_cursor._node.idx) + symbolic_idxs = [] + in_bounds_constraint = TRUE_CONSTRAINT + for dim in symbolic_tensor.dims: + if isinstance(dim, SymbolicSlice): + idx = self.cm.make_expression( + next(idx_expr_iter), self.var_renaming + ).add(dim.lower_bound) + in_bounds_constraint = in_bounds_constraint.intersect( + Constraint( + idx.negate().add(dim.upper_bound), True + ).lift_to_disjoint_constraint() + ) + symbolic_idxs.append(idx) + else: + symbolic_idxs.append(dim.index) + access_placeholder = self.parent_transpiler._make_placeholder() + access_args = ( + js_tensor, + js_idxs, + tuple(symbolic_idxs), + access_placeholder, + access_cursor, + self.current_node, + is_write, + ) + if self.failure_tracker is not None: + self.failure_tracker.add_assertion( + in_bounds_constraint, + in_bounds_js, + self.parent_transpiler._make_placeholder(), + ) + self.aliasing_tracker.access_tensor(*access_args) + if self.stage_mem_tracker is not None: + self.stage_mem_tracker.access_tensor(*access_args) + self.parallel_access_tracker.access_tensor(*access_args) + + def access_scalar(self, access_cursor: Node, is_write: bool): + access_placeholder = self.parent_transpiler._make_placeholder() + scalar_sym = self.scalar_symbols[access_cursor._node.name] + access_args = ( + scalar_sym, + access_placeholder, + access_cursor, + self.current_node, + is_write, + ) + self.aliasing_tracker.access_scalar(*access_args) + self.parallel_access_tracker.access_scalar(*access_args) + + def access_context(self, access_cursor: Node, is_write: bool): + access_placeholder = self.parent_transpiler._make_placeholder() + config_key = (access_cursor._node.config, access_cursor._node.field) + if config_key not in self.ctxt_symbols: + self.ctxt_symbols[config_key] = Sym("ctxt") + ctxt_sym = self.ctxt_symbols[config_key] + access_args = ( + ctxt_sym, + access_placeholder, + access_cursor, + self.current_node, + is_write, + ) + self.aliasing_tracker.access_scalar(*access_args) + self.parallel_access_tracker.access_scalar(*access_args) + + def make_skeleton(self) -> CoverageSkeleton: + return CoverageSkeleton( + (self.root,), + self.aliasing_tracker.make_aliasing_accesses(), + ( + () + if self.failure_tracker is None + else self.failure_tracker.make_failures() + ), + ( + () + if self.stage_mem_tracker is None + else self.stage_mem_tracker.make_staging_overlaps() + ), + ( + () + if self.stage_mem_tracker is None + else self.stage_mem_tracker.make_staging_bound_checks() + ), + self.parallel_access_tracker.make_parallel_access_pairs(), + frozenset(self.free_vars), + ) + + +def get_shape_cursor(type_cursor: Node) -> Block: + return ( + type_cursor._child_node("as_tensor") + if isinstance(type_cursor._node, LoopIR.WindowType) + else type_cursor + )._child_block("hi") + + +class Transpiler: + def __init__(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs] = None): + self._name_lookup: dict[Sym, ExoValue] = {} + self._js_lines: list[str] = [] + self._configs: set[tuple[Config, str]] = set() + self._buffer_args: list[Sym] = [] + self._coverage_state: Optional[CoverageState] = None + self._skeleton: Optional[CoverageSkeleton] = None + self._proc = proc + self._transpile_proc(proc, coverage_args) + + def get_javascript_template(self) -> Template: + return Template("\n".join(self._js_lines)) + + def get_configs(self) -> frozenset[tuple[Config, str]]: + return frozenset(self._configs) + + def get_buffer_arg_order(self) -> tuple[Sym, ...]: + return 
tuple(self._buffer_args)
+
+    def get_config_param_name(self, config: Config, field: str) -> str:
+        return f"config_{config.name()}_{field}"
+
+    def get_stride_param_name(self, tensor_name: Sym, dim_idx: int):
+        return f"stride_{repr(tensor_name)}_{dim_idx}"
+
+    def get_size_param_name(self, tensor_name: Sym, dim_idx: int):
+        return f"size_{repr(tensor_name)}_{dim_idx}"
+
+    def get_coverage_skeleton(self) -> Optional[CoverageSkeleton]:
+        return self._skeleton
+
+    def get_proc(self) -> LoopIR.proc:
+        return self._proc
+
+    def _assert_at_runtime(self, expr: str):
+        # Parenthesize so `!` negates the whole condition, not just its first
+        # `&&` operand; skip empty conditions, which would be invalid JS.
+        if expr:
+            self._js_lines.append(
+                f"if(!({expr}))return [1,{CONTEXT_OBJECT_NAME},{{}}];"
+            )
+
+    # for CoverageState
+    def _make_placeholder(self) -> int:
+        placeholder_index = len(self._js_lines)
+        self._js_lines.append("")
+        return placeholder_index
+
+    # for CoverageState
+    def _lookup_sym(self, sym: Sym) -> ExoValue:
+        return self._name_lookup[sym]
+
+    def _transpile_proc(self, proc: LoopIR.proc, coverage_args: Optional[CoverageArgs]):
+        self._buffer_args = [arg.name for arg in proc.args if arg.type.is_numeric()]
+        root_cursor = Cursor.create(proc)
+        self._js_lines.append(
+            f'(({",".join(repr(arg) for arg in self._buffer_args)})=>{{'
+        )
+        ctxt_placeholder = self._make_placeholder()
+        if coverage_args is not None:
+            self._coverage_state = CoverageState(coverage_args, self)
+        arg_values = []
+        for arg_cursor in root_cursor._child_block("args"):
+            arg = arg_cursor._node
+            if arg.type.is_numeric():
+                if isinstance(arg.type, LoopIR.Tensor):
+                    value = Tensor(
+                        arg.name,
+                        tuple(
+                            Dimension(
+                                f"${size}",
+                                f"${stride}",
+                                Slice("0", f"${size}"),
+                            )
+                            for size, stride in map(
+                                lambda dim_idx: (
+                                    self.get_size_param_name(arg.name, dim_idx),
+                                    self.get_stride_param_name(arg.name, dim_idx),
+                                ),
+                                range(len(arg.type.shape())),
+                            )
+                        ),
+                    )
+                    if self._coverage_state is not None:
+                        self._coverage_state.make_tensor(
+                            arg.name, arg.type.shape(), "true"
+                        )
+                elif isinstance(arg.type, LoopIR.WindowType):
+                    value = self._transpile_window(
+                        arg.name,
+                        arg.type.src_buf,
+                        arg_cursor._child_node("type")._child_block("idx"),
+                    )
+                else:
+                    value = Reference(repr(arg.name), False)
+                    if self._coverage_state is not None:
+                        self._coverage_state.make_scalar(arg.name)
+            else:
+                value = Constant(f"${repr(arg.name)}")
+            arg_values.append(value)
+        self._call_proc(root_cursor, tuple(arg_values), True)
+        coverage_object = ""
+        if self._coverage_state is not None:
+            skeleton = self._coverage_state.make_skeleton()
+            self._skeleton = skeleton
+            coverage_object = f"{{{','.join(sorted(repr(sym) for sym in self._skeleton.get_coverage_syms()))}}}"
+            for indexed_filler in sorted(set(skeleton.get_indexed_fillers())):
+                self._js_lines[indexed_filler.index] += indexed_filler.placefiller
+        self._js_lines.append(f"return [0,{CONTEXT_OBJECT_NAME},{coverage_object}];}})")
+        configs = ",".join(
+            f"{self.get_config_param_name(config, field)}:${self.get_config_param_name(config, field)}"
+            for config, field in self._configs
+        )
+        self._js_lines[ctxt_placeholder] = f"{CONTEXT_OBJECT_NAME}={{{configs}}}"
+
+    def _transpile_window(
+        self, name: Sym, source_buf: Sym, access_cursor: Block
+    ) -> Tensor:
+        base = self._name_lookup[source_buf]
+        assert isinstance(base, Tensor)
+        window_dims = []
+        in_bounds_conds = []
+        idx_cursor_iter = iter(access_cursor)
+        for dim in base.dims:
+            if isinstance(dim.window_idx, Point):
+                window_dims.append(dim)
+            else:
+                idx_cursor = next(idx_cursor_iter)
+                idx = idx_cursor._node
+                if isinstance(idx, LoopIR.Interval):
+                    lo_expr = self._transpile_expr(
+                        
idx_cursor._child_node("lo"), + ) + hi_expr = self._transpile_expr( + idx_cursor._child_node("hi"), + ) + lo_sym, hi_sym = Sym("lo"), Sym("hi") + self._js_lines.append( + f"let {repr(lo_sym)}=({lo_expr})+({dim.window_idx.lower_bound});" + ) + self._js_lines.append( + f"let {repr(hi_sym)}=({hi_expr})+({dim.window_idx.lower_bound});" + ) + in_bounds_conds.append( + f"(({dim.window_idx.lower_bound})<=({repr(lo_sym)})&&({repr(lo_sym)})<=({repr(hi_sym)})&&({repr(hi_sym)})<=({dim.window_idx.upper_bound}))" + ) + window_dims.append( + Dimension( + dim.size, dim.stride, Slice(repr(lo_sym), repr(hi_sym)) + ) + ) + elif isinstance(idx, LoopIR.Point): + pt_expr = self._transpile_expr( + idx_cursor._child_node("pt"), + ) + index_sym = Sym("idx") + self._js_lines.append( + f"let {repr(index_sym)}=({pt_expr})+({dim.window_idx.lower_bound});" + ) + in_bounds_conds.append( + f"(({dim.window_idx.lower_bound})<={repr(index_sym)}&&{repr(index_sym)}<({dim.window_idx.upper_bound}))" + ) + window_dims.append( + Dimension(dim.size, dim.stride, Point(repr(index_sym))) + ) + else: + assert False, "not a window index" + in_bounds_js = "&&".join(in_bounds_conds) + if self._coverage_state is not None: + self._coverage_state.assign_window( + name, source_buf, access_cursor, in_bounds_js + ) + self._assert_at_runtime(in_bounds_js) + return Tensor(base.name, tuple(window_dims)) + + def _call_proc( + self, proc_cursor: Node, arg_values: tuple[ExoValue, ...], top_level: bool + ): + for arg_cursor, arg_value in zip(proc_cursor._child_block("args"), arg_values): + arg = arg_cursor._node + self._name_lookup[arg.name] = arg_value + if arg.type.is_tensor_or_window(): + assert isinstance(arg_value, Tensor) + shape_matches_js = "&&".join( + f"((({arg_dim_slice.upper_bound})-({arg_dim_slice.lower_bound}))==({self._transpile_expr(arg_dim_cursor)}))" + for arg_dim_slice, arg_dim_cursor in zip( + ( + dim.window_idx + for dim in arg_value.dims + if isinstance(dim.window_idx, Slice) + ), + get_shape_cursor(arg_cursor._child_node("type")), + ) + ) + if self._coverage_state is not None and not top_level: + self._coverage_state.assert_shape_matches( + arg.name, arg.type.shape(), shape_matches_js + ) + self._assert_at_runtime(shape_matches_js) + + for pred_cursor in proc_cursor._child_block("preds"): + js_pred = self._transpile_expr(pred_cursor) + if self._coverage_state is not None and not top_level: + self._coverage_state.assert_predicate(pred_cursor._node, js_pred) + self._assert_at_runtime(js_pred) + + if self._coverage_state is not None: + self._coverage_state.enter_proc_body() + for stmt_cursor in proc_cursor._child_block("body"): + self._transpile_stmt(stmt_cursor) + if self._coverage_state is not None: + self._coverage_state.exit_proc_body() + + def _get_index_exprs( + self, + buf: Tensor, + idxs: Block, + ) -> tuple[str, ...]: + idx_expr_iter = iter(self._transpile_expr(idx) for idx in idxs) + idx_parts = [] + for dim in buf.dims: + if isinstance(dim.window_idx, Slice): + idx_expr = next(idx_expr_iter) + idx_parts.append(f"(({idx_expr})+({dim.window_idx.lower_bound}))") + else: + idx_parts.append(f"({dim.window_idx.index})") + return tuple(idx_parts) + + def _get_in_bounds_condition( + self, index_exprs: tuple[str, ...], buf: Tensor + ) -> str: + return "&&".join( + f"(({index_expr})>=({dim.window_idx.lower_bound})&&({index_expr})<({dim.window_idx.upper_bound}))" + for index_expr, dim in zip(index_exprs, buf.dims) + if isinstance(dim.window_idx, Slice) + ) + + def _transpile_stmt( + self, + stmt_cursor: Node, + ): + if 
self._coverage_state is not None: + self._coverage_state.enter_stmt(stmt_cursor) + stmt = stmt_cursor._node + if isinstance(stmt, (LoopIR.Assign, LoopIR.Reduce)): + lhs_buffer = self._name_lookup[stmt.name] + rhs = self._transpile_expr(stmt_cursor._child_node("rhs")) + if isinstance(lhs_buffer, Reference): + lhs = ( + lhs_buffer.name if lhs_buffer.is_config else f"{lhs_buffer.name}[0]" + ) + if self._coverage_state is not None: + self._coverage_state.access_scalar(stmt_cursor, True) + elif isinstance(lhs_buffer, Tensor): + index_exprs = self._get_index_exprs( + lhs_buffer, + stmt_cursor._child_block("idx"), + ) + index = f"+".join( + f"Math.imul({dim.stride},{index_expr})" + for dim, index_expr in zip(lhs_buffer.dims, index_exprs) + ) + lhs = f"{repr(lhs_buffer.name)}[{index}]" + in_bounds_js = self._get_in_bounds_condition(index_exprs, lhs_buffer) + if self._coverage_state is not None: + self._coverage_state.access_tensor( + stmt_cursor, index_exprs, True, in_bounds_js + ) + self._assert_at_runtime(in_bounds_js) + else: + assert False + if isinstance(stmt, LoopIR.Assign): + self._js_lines.append(f"{lhs}={rhs};") + else: + self._js_lines.append(f"{lhs}+={rhs};") + elif isinstance(stmt, LoopIR.WriteConfig): + config_name = self.get_config_param_name(stmt.config, stmt.field) + rhs = self._transpile_expr(stmt_cursor._child_node("rhs")) + self._js_lines.append(f'{CONTEXT_OBJECT_NAME}["{config_name}"]={rhs};') + self._configs.add((stmt.config, stmt.field)) + if self._coverage_state is not None: + self._coverage_state.access_context(stmt_cursor, True) + elif isinstance(stmt, LoopIR.Pass): + pass + elif isinstance(stmt, LoopIR.If): + cond = self._transpile_expr(stmt_cursor._child_node("cond")) + self._js_lines.append(f"if({cond}){{") + + def transpile_if_body(): + for body_cursor in stmt_cursor._child_block("body"): + self._transpile_stmt(body_cursor) + self._js_lines.append("}else{") + + def transpile_else_body(): + for else_cursor in stmt_cursor._child_block("orelse"): + self._transpile_stmt(else_cursor) + self._js_lines.append("}") + + if self._coverage_state is not None: + self._coverage_state.enter_if( + stmt, transpile_if_body, transpile_else_body + ) + else: + transpile_if_body() + transpile_else_body() + elif isinstance(stmt, LoopIR.For): + iter_name = repr(stmt.iter) + iter_lo = self._transpile_expr(stmt_cursor._child_node("lo")) + iter_hi = self._transpile_expr(stmt_cursor._child_node("hi")) + self._name_lookup[stmt.iter] = Constant(iter_name) + + def transpile_loop_body(): + self._js_lines.append( + f"for(let {iter_name}={iter_lo};{iter_name}<{iter_hi};{iter_name}++){{" + ) + for body_cursor in stmt_cursor._child_block("body"): + self._transpile_stmt(body_cursor) + + if self._coverage_state is not None: + self._coverage_state.enter_loop( + stmt, iter_lo, iter_hi, transpile_loop_body + ) + else: + transpile_loop_body() + + self._js_lines.append("}") + elif isinstance(stmt, LoopIR.Alloc): + assert stmt.type.is_numeric() + if stmt.type.is_tensor_or_window(): + tensor_name: Sym = stmt.name + buffer_type = lookup_loopir_type( + stmt.type.basetype() + ).javascript_array_type + shape_cursor = get_shape_cursor(stmt_cursor._child_node("type")) + dims = len(shape_cursor) + self._js_lines.append( + "".join( + f"let {self.get_size_param_name(tensor_name, dim_idx)}={self._transpile_expr(dim_cursor)};" + for dim_idx, dim_cursor in enumerate(shape_cursor) + ) + ) + self._js_lines.append( + "".join( + f'let {self.get_stride_param_name(tensor_name, dim_idx)}={f"1" if dim_idx + 1 == dims else 
f"Math.imul({self.get_stride_param_name(tensor_name, dim_idx + 1)},{self.get_size_param_name(tensor_name,dim_idx + 1)})"};' + for dim_idx in reversed(range(dims)) + ) + ) + nonnegative_dims_js = "&&".join( + f"({self.get_size_param_name(tensor_name, dim_idx)}>=0)" + for dim_idx in range(dims) + ) + if self._coverage_state is not None: + self._coverage_state.make_tensor( + tensor_name, stmt.type.shape(), nonnegative_dims_js + ) + self._assert_at_runtime(nonnegative_dims_js) + buffer_size = f"Math.imul({self.get_size_param_name(tensor_name, 0)},{self.get_stride_param_name(tensor_name, 0)})" + self._js_lines.append( + f"let {repr(tensor_name)}=new {buffer_type}({buffer_size});for(let i=0;i<{buffer_size};i++){{{repr(tensor_name)}[i]=(Math.random()-0.5)*(1<<30);}}" + ) + self._name_lookup[stmt.name] = Tensor( + stmt.name, + tuple( + Dimension(size, stride, Slice("0", size)) + for size, stride in map( + lambda dim_idx: ( + self.get_size_param_name(tensor_name, dim_idx), + self.get_stride_param_name(tensor_name, dim_idx), + ), + range(dims), + ) + ), + ) + else: + ref_name = repr(stmt.name) + buffer_type = lookup_loopir_type(stmt.type).javascript_array_type + if self._coverage_state is not None: + self._coverage_state.make_scalar(stmt.name) + self._js_lines.append(f"let {ref_name}=new {buffer_type}(1);") + self._name_lookup[stmt.name] = Reference(ref_name, False) + elif isinstance(stmt, LoopIR.Free): + pass + elif isinstance(stmt, LoopIR.Call): + self._call_proc( + stmt_cursor._child_node("f"), + tuple( + ( + self._transpile_buffer_arg( + arg_val_cursor, arg_name_cursor._node.name + ) + if arg_val_cursor._node.type.is_numeric() + else Constant(self._transpile_expr(arg_val_cursor)) + ) + for arg_val_cursor, arg_name_cursor in zip( + stmt_cursor._child_block("args"), + stmt_cursor._child_node("f")._child_block("args"), + ) + ), + False, + ) + elif isinstance(stmt, LoopIR.WindowStmt): + self._name_lookup[stmt.name] = self._transpile_buffer_arg( + stmt_cursor._child_node("rhs"), stmt.name + ) + else: + assert False, "unsupported stmt" + + if self._coverage_state is not None: + self._coverage_state.exit_stmt(stmt_cursor) + + def _transpile_buffer_arg( + self, expr_cursor: Node, new_name: Sym + ) -> Union[Tensor, Reference]: + expr = expr_cursor._node + if isinstance(expr, LoopIR.Read): + assert len(expr.idx) == 0 + buf = self._name_lookup[expr.name] + assert isinstance(buf, (Tensor, Reference)) + if self._coverage_state is not None: + if isinstance(buf, Tensor): + self._coverage_state.assign_tensor(new_name, expr.name) + else: + self._coverage_state.assign_scalar(new_name, expr.name) + return buf + elif isinstance(expr, LoopIR.WindowExpr): + return self._transpile_window( + new_name, expr.name, expr_cursor._child_block("idx") + ) + elif isinstance(expr, LoopIR.ReadConfig): + self._configs.add((expr.config, expr.field)) + if self._coverage_state is not None: + self._coverage_state.assign_scalar_from_context( + new_name, expr.config, expr.field + ) + return Reference( + f'{CONTEXT_OBJECT_NAME}["{self.get_config_param_name(expr.config, expr.field)}"]', + True, + ) + else: + assert False, "unsupported buffer expression" + + def _transpile_expr( + self, + expr_cursor: Node, + ) -> str: + expr = expr_cursor._node + if isinstance(expr, LoopIR.Read): + buf = self._name_lookup[expr.name] + if isinstance(buf, Tensor): + index_exprs = self._get_index_exprs( + buf, + expr_cursor._child_block("idx"), + ) + index = f"+".join( + f"Math.imul({dim.stride},{index_expr})" + for dim, index_expr in zip(buf.dims, 
index_exprs) + ) + in_bounds_js = self._get_in_bounds_condition(index_exprs, buf) + if self._coverage_state is not None: + self._coverage_state.access_tensor( + expr_cursor, index_exprs, False, in_bounds_js + ) + self._assert_at_runtime(in_bounds_js) + return f"{repr(buf.name)}[{index}]" + elif isinstance(buf, Reference): + if self._coverage_state is not None: + self._coverage_state.access_scalar(expr_cursor, False) + return buf.name if buf.is_config else f"{buf.name}[0]" + else: + return buf.name + elif isinstance(expr, LoopIR.Const): + if isinstance(expr.val, bool): + return "true" if expr.val else "false" + elif isinstance(expr.val, (int, float)): + return f"{expr.val}" + else: + assert False, "unexpected const type" + elif isinstance(expr, LoopIR.USub): + return f"(-{self._transpile_expr(expr_cursor._child_node('arg'))})" + elif isinstance(expr, LoopIR.BinOp): + lhs = self._transpile_expr(expr_cursor._child_node("lhs")) + rhs = self._transpile_expr(expr_cursor._child_node("rhs")) + is_int = ( + isinstance(expr.type, (T.INT8, T.UINT8, T.UINT16, T.INT32)) + or not expr.type.is_numeric() + ) + if expr.op in ["+", "-", "%", "<", ">", "<=", ">=", "=="]: + val = f"({lhs}{expr.op}{rhs})" + elif expr.op == "*": + val = f"Math.imul({lhs},{rhs})" if is_int else f"({lhs}*{rhs})" + elif expr.op == "/": + val = f"(({lhs}/{rhs})|0)" if is_int else f"({lhs}/{rhs})" + elif expr.op == "and": + val = f"({lhs}&&{rhs})" + elif expr.op == "or": + val = f"({lhs}||{rhs})" + else: + assert False, "invalid op" + if isinstance(expr.type, T.INT8): + return f"(({val}<<24)>>24)" + elif isinstance(expr.type, T.UINT8): + return f"({val}&0xFF)" + elif isinstance(expr.type, T.UINT16): + return f"({val}&0xFFFF)" + else: + return val + elif isinstance(expr, LoopIR.Extern): + return expr.f.transpile( + tuple( + self._transpile_expr(arg_cursor) + for arg_cursor in expr_cursor._child_block("args") + ) + ) + elif isinstance(expr, LoopIR.WindowExpr): + assert False, "unexpected window expr" + elif isinstance(expr, LoopIR.StrideExpr): + buf = self._name_lookup[expr.name] + assert isinstance(buf, Tensor) + return tuple(dim for dim in buf.dims if isinstance(dim.window_idx, Slice))[ + expr.dim + ].stride + elif isinstance(expr, LoopIR.ReadConfig): + self._configs.add((expr.config, expr.field)) + if self._coverage_state is not None: + self._coverage_state.access_context(expr_cursor, False) + return f'{CONTEXT_OBJECT_NAME}["{self.get_config_param_name(expr.config, expr.field)}"]' + else: + assert False, "unexpected expr" diff --git a/src/exo/rewrite/chexo/__init__.py b/src/exo/rewrite/chexo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/rewrite/chexo/chexo.py b/src/exo/rewrite/chexo/chexo.py new file mode 100644 index 000000000..c983ba63d --- /dev/null +++ b/src/exo/rewrite/chexo/chexo.py @@ -0,0 +1,818 @@ +from typing import Literal, Optional, Union + +from ...core.internal_cursors import Block, Node, NodePath + +from .LoopIR_transpiler import CoverageArgs, StageMemArgs, Transpiler +from .coverage import CoverageSkeleton + +from ...core.configs import Config + +from ...core.LoopIR import LoopIR, T +from dataclasses import dataclass, field +from ...core.prelude import Sym, SrcInfo +from ...core.memory import DRAM, Memory +import numpy as np +from ..new_eff import SchedulingError +from .constraint_solver import ( + TRUE_CONSTRAINT, + ConstraintMaker, + DisjointConstraint, +) + +from pythonmonkey import eval as js_eval + + +class LoopIRVisitor: + def visit(self, node): + self.visit_generic(node) + + 
def visit_generic(self, node): + if ( + isinstance(node, LoopIR.proc) + or isinstance(node, LoopIR.instr) + or isinstance(node, LoopIR.fnarg) + or isinstance(node, LoopIR.stmt) + or isinstance(node, LoopIR.loop_mode) + or isinstance(node, LoopIR.expr) + or isinstance(node, LoopIR.w_access) + or isinstance(node, LoopIR.type) + ): + for field_name in dir(node): + if not field_name.startswith("_"): + field = getattr(node, field_name) + if isinstance(field, list): + for child in field: + self.visit(child) + else: + self.visit(field) + + +@dataclass +class TypeVisitor(LoopIRVisitor): + type_map: dict[Sym, LoopIR.type] = field(default_factory=lambda: {}) + mem_map: dict[Sym, Memory] = field(default_factory=lambda: {}) + + def visit(self, node): + if isinstance(node, LoopIR.For): + self.type_map[node.iter] = T.Index() + self.visit_generic(node) + elif isinstance(node, LoopIR.Alloc): + self.type_map[node.name] = node.type + self.mem_map[node.name] = node.mem + elif isinstance(node, LoopIR.WindowStmt): + self.type_map[node.name] = node.rhs.type + elif isinstance(node, LoopIR.fnarg): + self.type_map[node.name] = node.type + if node.mem: + self.mem_map[node.name] = node.mem + elif isinstance(node, LoopIR.Call): + for arg_val, arg in zip(node.args, node.f.args): + if isinstance(arg.type, LoopIR.Tensor) and arg.type.is_window: + self.type_map[arg.name] = arg_val.type + else: + self.type_map[arg.name] = arg.type + for stmt in node.f.body: + self.visit(stmt) + else: + self.visit_generic(node) + + +@dataclass +class UsedVariableVisitor(LoopIRVisitor): + used_vars: set[Sym] = field(default_factory=lambda: set()) + + def visit(self, node): + if isinstance(node, Sym): + self.used_vars.add(node) + else: + self.visit_generic(node) + + +@dataclass +class ModifiedVariableVisitor(LoopIRVisitor): + type_map: dict[Sym, LoopIR.type] + modified_vars: set[Sym] = field(default_factory=lambda: set()) + + def visit(self, node): + if isinstance(node, (LoopIR.Assign, LoopIR.Reduce)): + node_type = self.type_map[node.name] + if isinstance(node_type, LoopIR.WindowType): + self.modified_vars.add(node_type.src_buf) + else: + self.modified_vars.add(node.name) + else: + self.visit_generic(node) + + +class LoopIRModifier: + def visit(self, node): + return self.visit_generic(node) + + def visit_generic(self, node): + if ( + isinstance(node, LoopIR.proc) + or isinstance(node, LoopIR.instr) + or isinstance(node, LoopIR.fnarg) + or isinstance(node, LoopIR.stmt) + or isinstance(node, LoopIR.loop_mode) + or isinstance(node, LoopIR.expr) + or isinstance(node, LoopIR.w_access) + or isinstance(node, LoopIR.type) + ): + updates = {} + for field_name in dir(node): + if not field_name.startswith("_"): + field = getattr(node, field_name) + if isinstance(field, list): + new_field = field + for child_idx, child in enumerate(field): + new_child = self.visit(child) + if new_child != child: + if new_field == field: + new_field = field.copy() + new_field[child_idx] = new_child + else: + new_field = self.visit(field) + if new_field != field: + updates[field_name] = new_field + return node.update(**updates) if len(updates) != 0 else node + + +@dataclass +class ReadWriteSyms: + reduced_syms: set[Sym] + assigned_syms: set[Sym] + written_configs: set[tuple[Config, str]] + read_syms: set[Sym] + + +@dataclass +class Dimension: + size: int + stride: int + + +@dataclass +class Tensor: + data: np.ndarray + dims: tuple[Dimension, ...] 
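+    # Buffers cross into the transpiled JavaScript as flat arrays: `data` is
+    # the flattened ndarray, and `dims` records each dimension's size and
+    # stride, with strides counted in elements rather than bytes
+    # (generate_numeric_value below divides numpy's byte strides by itemsize).
+    # For example, a C-contiguous 3x4 float32 buffer becomes
+    #     Tensor(data=<12 floats>, dims=(Dimension(3, 4), Dimension(4, 1)))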
+
+
+def get_free_variables(type_map, mem_map, fragment: Union[Block, Node]):
+    fragment_type_visitor = TypeVisitor()
+    fragment_var_visitor = UsedVariableVisitor()
+    if isinstance(fragment, Block):
+        for fragment_node in fragment.resolve_all():
+            fragment_type_visitor.visit(fragment_node)
+            fragment_var_visitor.visit(fragment_node)
+    else:
+        fragment_type_visitor.visit(fragment._node)
+        fragment_var_visitor.visit(fragment._node)
+    for var in fragment_var_visitor.used_vars - fragment_type_visitor.type_map.keys():
+        fragment_var_visitor.visit(type_map[var])
+    return {
+        var: (type_map[var], mem_map[var] if var in mem_map else None)
+        for var in fragment_var_visitor.used_vars
+        - fragment_type_visitor.type_map.keys()
+    }
+
+
+def eval_tensor_dimension(
+    dim_expr: LoopIR.expr, arg_values: dict[Sym, Union[int, bool, float, Tensor]]
+) -> int:
+    if isinstance(dim_expr, LoopIR.Read):
+        return arg_values[dim_expr.name]
+    elif isinstance(dim_expr, LoopIR.Const):
+        return dim_expr.val
+    elif isinstance(dim_expr, LoopIR.USub):
+        return -eval_tensor_dimension(dim_expr.arg, arg_values)
+    elif isinstance(dim_expr, LoopIR.BinOp):
+        lhs, rhs = eval_tensor_dimension(
+            dim_expr.lhs, arg_values
+        ), eval_tensor_dimension(dim_expr.rhs, arg_values)
+        if dim_expr.op == "+":
+            return lhs + rhs
+        elif dim_expr.op == "-":
+            return lhs - rhs
+        elif dim_expr.op == "*":
+            return lhs * rhs
+        elif dim_expr.op == "/":
+            if isinstance(lhs, int) and isinstance(rhs, int):
+                # match C semantics: integer division truncates toward zero.
+                # An earlier version used ceiling division,
+                # `(lhs + rhs - 1) // rhs`, which disagrees with C
+                # (e.g. 3 / 2 is 1 in C but 2 under ceiling division).
+                return int(lhs / rhs)
+            else:
+                return lhs / rhs
+        elif dim_expr.op == "%":
+            return lhs % rhs
+        else:
+            assert False, "unexpected binop in tensor dimension"
+    else:
+        assert False, "unexpected expression type in tensor dimension"
+
+
+CONTROL_VAL_BOUND = 128
+SEARCH_LIMIT = 10
+INT_BOUND = 128
+FLOAT_BOUND = 32
+
+
+def collect_path_constraints(
+    cursor: Union[Block, Node], cm: ConstraintMaker, type_map: dict[Sym, LoopIR.type]
+) -> DisjointConstraint:
+    if isinstance(cursor, Block):
+        if len(cursor) > 0:
+            cursor = cursor[0]
+        else:
+            cursor = cursor._anchor
+    assert isinstance(cursor, Node)
+    if len(cursor._path) == 0:
+        return TRUE_CONSTRAINT
+    last_attr, last_index = cursor._path[-1]
+    cur = cursor.parent()
+    result = TRUE_CONSTRAINT
+    var_renaming = {}
+    while cur.depth() != 0:
+        if isinstance(cur, Node):
+            if isinstance(cur._node, LoopIR.For):
+                modified_variable_visitor = ModifiedVariableVisitor(type_map)
+                for stmt in cur._node.body:
+                    modified_variable_visitor.visit(stmt)
+                for var_sym in modified_variable_visitor.modified_vars:
+                    var_renaming[var_sym] = cm.copy_var(var_sym)
+                result = result.intersect(
+                    cm.make_constraint_from_inequality(
+                        cur._node.iter, cur._node.lo, ">=", var_renaming
+                    ).lift_to_disjoint_constraint()
+                )
+                result = result.intersect(
+                    cm.make_constraint_from_inequality(
+                        cur._node.iter, cur._node.hi, "<", var_renaming
+                    ).lift_to_disjoint_constraint()
+                )
+            elif isinstance(cur._node, LoopIR.If):
+                assert last_index is not None
+                modified_variable_visitor = ModifiedVariableVisitor(type_map)
+                for stmt, _ in zip(getattr(cur._node, last_attr), range(last_index)):
+                    modified_variable_visitor.visit(stmt)
+                for var_sym in modified_variable_visitor.modified_vars:
+                    var_renaming[var_sym] = cm.copy_var(var_sym)
+                constraint = cm.make_constraint(cur._node.cond, var_renaming)
+                if last_attr == "orelse":
+                    result = result.intersect(constraint.invert())
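+                    # a statement under `orelse` only executes when the `if`
+                    # condition is false, hence the inverted constraint above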
else: + result = result.intersect(constraint) + + last_attr, last_index = cur._path[-1] + cur = cur.parent() + + assert last_index is not None + modified_variable_visitor = ModifiedVariableVisitor(type_map) + for stmt, _ in zip(cur._node.body, range(last_index)): + modified_variable_visitor.visit(stmt) + for var_sym in modified_variable_visitor.modified_vars: + var_renaming[var_sym] = cm.copy_var(var_sym) + for pred in cur._node.preds: + result = result.intersect(cm.make_constraint(pred, var_renaming)) + return result + + +@dataclass +class TestCase: + arg_values: dict[Sym, Union[int, bool, float, Tensor]] + ctxt: dict[tuple[Config, str], Union[int, bool, float, Tensor]] + + +def generate_control_value(var_type: LoopIR.type) -> Union[int, bool, float]: + if isinstance(var_type, T.Bool): + return np.random.rand() < 0.5 + elif isinstance(var_type, (T.Size, T.Stride)): + return np.random.randint(1, CONTROL_VAL_BOUND) + elif isinstance(var_type, (T.Int, T.Index)): + return np.random.randint(-CONTROL_VAL_BOUND, CONTROL_VAL_BOUND) + else: + assert False, "not a control type" + + +def generate_numeric_value(var_type: LoopIR.type, shape: tuple[int, ...]) -> Tensor: + if isinstance(var_type, (T.F32, T.Num)): + dtype = np.float32 + elif isinstance(var_type, T.F16): + dtype = np.float16 + elif isinstance(var_type, T.F64): + dtype = np.float64 + elif isinstance(var_type, T.INT8): + dtype = np.int8 + elif isinstance(var_type, T.INT32): + dtype = np.int32 + elif isinstance(var_type, T.UINT8): + dtype = np.uint8 + elif isinstance(var_type, T.UINT16): + dtype = np.uint16 + else: + assert False, "not a numeric type" + + if dtype in [np.int8, np.int32]: + data = np.random.randint(-INT_BOUND, INT_BOUND, shape, dtype=dtype) + elif dtype in [np.uint8, np.uint16]: + data = np.random.randint(0, INT_BOUND, shape, dtype=dtype) + elif dtype in [np.float16, np.float32, np.float64]: + data = ((np.random.rand(*shape) * 2 - 1) * FLOAT_BOUND).astype(dtype) + else: + assert False, "unreachable" + + return Tensor( + data.flatten(), + tuple( + Dimension(dim_size, dim_stride / data.dtype.itemsize) + for dim_size, dim_stride in zip(data.shape, data.strides) + ), + ) + + +def generate_test_case( + arg_types: dict[Sym, LoopIR.type], + config_fields: frozenset[tuple[Config, str]], + constraint: DisjointConstraint, + coverage_skeleton: CoverageSkeleton, + cm: ConstraintMaker, +) -> Optional[TestCase]: + ctxt = {} + arg_values = {} + solution = coverage_skeleton.solve_constraint_with_coverage( + cm, constraint, bound=INT_BOUND, search_limit=SEARCH_LIMIT + ) + if solution is None: + return None + for config, field in config_fields: + if (config, field) in solution.ctxt: + ctxt[(config, field)] = solution.ctxt[(config, field)] + else: + field_type = config.lookup_type(field) + if field_type.is_numeric(): + val = generate_numeric_value(field_type, (1,)) + else: + val = generate_control_value(field_type) + ctxt[(config, field)] = val + + for arg_name, arg_type in arg_types.items(): + if not arg_type.is_numeric(): + if arg_name in solution.var_assignments: + if isinstance(arg_type, T.Bool): + val = solution.var_assignments[arg_name] != 0 + else: + val = solution.var_assignments[arg_name] + else: + val = generate_control_value(arg_type) + arg_values[arg_name] = val + + for arg_name, arg_type in arg_types.items(): + if arg_type.is_numeric() and not isinstance(arg_type, LoopIR.WindowType): + if arg_type.is_real_scalar(): + shape = (1,) + else: + shape = tuple( + eval_tensor_dimension(dim_expr, arg_values) + for dim_expr in 
arg_type.shape() + ) + arg_values[arg_name] = generate_numeric_value(arg_type.basetype(), shape) + + return TestCase(arg_values, ctxt) + + +@dataclass +class TestResult: + buffer_values: dict[Sym, np.ndarray] + ctxt_object: dict[str, Union[int, float]] + # map from JavaScript name of variable tracking coverage of different parts of coverage skeleton + # e.g. bool variable tracking whether branch of if statement gets executed + coverage_result: Optional[dict[str, Union[bool, memoryview]]] + + +def run_test_case( + test_case: TestCase, + transpiled_proc: Transpiler, +) -> Union[TestResult, Literal["failed"]]: + subs = {} + for arg_name, arg_value in test_case.arg_values.items(): + if isinstance(arg_value, Tensor): + for dim_idx, dim in enumerate(arg_value.dims): + subs[transpiled_proc.get_size_param_name(arg_name, dim_idx)] = str( + dim.size + ) + subs[transpiled_proc.get_stride_param_name(arg_name, dim_idx)] = str( + dim.stride + ) + elif isinstance(arg_value, bool): + subs[repr(arg_name)] = "true" if arg_value else "false" + elif isinstance(arg_value, (int, float)): + subs[repr(arg_name)] = str(arg_value) + else: + assert False + for (config, field), config_value in test_case.ctxt.items(): + if isinstance(config_value, Tensor): + assert config_value.data.shape == (1,) + subs[transpiled_proc.get_config_param_name(config, field)] = str( + config_value.data[0] + ) + elif isinstance(config_value, bool): + subs[transpiled_proc.get_config_param_name(config, field)] = ( + "true" if config_value else "false" + ) + elif isinstance(config_value, (int, float)): + subs[transpiled_proc.get_config_param_name(config, field)] = str( + config_value + ) + else: + assert False + + buffer_args = tuple( + test_case.arg_values[buffer_name].data.copy() + for buffer_name in transpiled_proc.get_buffer_arg_order() + ) + javascript = transpiled_proc.get_javascript_template().substitute(subs) + try: + eval_info = js_eval(javascript)(*buffer_args) + except Exception as e: + raise Exception( + f"javascript:\n{javascript}\nproc:\n{transpiled_proc.get_proc()}" + ) from e + if transpiled_proc.get_coverage_skeleton() is None: + [result, ctxt_object] = eval_info + coverage_result = None + else: + [result, ctxt_object, coverage_result] = eval_info + if result != 0: + return "failed" + return TestResult( + { + buffer_name: buffer_value + for buffer_name, buffer_value in zip( + transpiled_proc.get_buffer_arg_order(), buffer_args + ) + }, + ctxt_object, + coverage_result, + ) + + +@dataclass +class TestSpec: + proc: LoopIR.proc + constraint: DisjointConstraint + arg_types: dict[Sym, LoopIR.type] + original_scope: Block + var_renaming: dict[Sym, Sym] + + def forward_to_test(self, cursor: Block) -> Optional[Block]: + if cursor in self.original_scope: + return Block( + self.proc, + Node(self.proc, []), + "body", + range( + cursor._range.start - self.original_scope._range.start, + cursor._range.stop - self.original_scope._range.stop, + ), + ) + for node_idx, node in enumerate(self.original_scope): + if node.is_ancestor_of(cursor): + return Block( + self.proc, + Node( + self.proc, + [("body", node_idx)] + cursor._anchor._path[len(node._path) :], + ), + cursor._attr, + cursor._range, + ) + return None + + def forward_staging_args( + self, staging_args: Optional[StageMemArgs] + ) -> Optional[StageMemArgs]: + if staging_args is None: + return None + forwarded_scope = self.forward_to_test(staging_args.scope) + if forwarded_scope is None: + return None + return StageMemArgs( + staging_args.buffer_sym, staging_args.staged_window_expr, 
forwarded_scope + ) + + def backward_from_test(self, path: NodePath) -> NodePath: + assert path.path[0][1] is not None + return NodePath( + tuple(self.original_scope._anchor._path) + + ( + ( + self.original_scope._attr, + self.original_scope._range.start + path.path[0][1], + ), + ) + + path.path[1:] + ) + + +@dataclass +class TestScope: + scope: Block + + def broaden(self) -> Optional["TestScope"]: + if self.scope._anchor.depth() == 0: + new_scope = self.scope.expand() + if ( + new_scope._range.start == self.scope._range.start + and new_scope._range.stop == self.scope._range.stop + ): + return None + else: + return TestScope(new_scope) + else: + return TestScope(self.scope._anchor.as_block()) + + def get_type_map(self) -> dict[Sym, LoopIR.type]: + root_proc = self.scope.get_root() + proc_type_visitor = TypeVisitor() + proc_type_visitor.visit(root_proc) + return proc_type_visitor.type_map + + def get_test_spec( + self, cm: ConstraintMaker, type_map: dict[Sym, LoopIR.type] + ) -> TestSpec: + root_proc = self.scope.get_root() + proc_type_visitor = TypeVisitor() + proc_type_visitor.visit(root_proc) + + constraint = TRUE_CONSTRAINT + constraint = constraint.intersect( + collect_path_constraints(self.scope, cm, type_map) + ) + args = [ + LoopIR.fnarg( + name=var, + type=arg_type, + mem=DRAM if arg_mem is None else arg_mem, + srcinfo=SrcInfo("", 0), + ) + for var, (arg_type, arg_mem) in get_free_variables( + proc_type_visitor.type_map, + proc_type_visitor.mem_map, + self.scope, + ).items() + ] + args = sorted( + args, + key=lambda arg: ( + (2 if isinstance(arg.type, LoopIR.WindowType) else 1) + if arg.type.is_numeric() + else 0 + ), + ) + + proc = LoopIR.proc( + name=root_proc.name, + args=args, + preds=[], + body=(self.scope.resolve_all()), + instr=None, + srcinfo=root_proc.srcinfo, + ) + modified_variable_visitor = ModifiedVariableVisitor(type_map) + modified_variable_visitor.visit(proc) + arg_types = {arg.name: arg.type for arg in args} + return TestSpec( + proc, + constraint, + arg_types, + self.scope, + {sym: cm.top_var(sym) for sym in modified_variable_visitor.modified_vars}, + ) + + +TEST_CASE_BOUND = 15 +MAX_SKIPPED_TESTS = 3 +MAX_ITERS = 20 + + +@dataclass +class StageMemResult: + overlapping_accesses: frozenset[NodePath] + dim_needs_upper_bound_guard: tuple[bool, ...] + dim_needs_lower_bound_guard: tuple[bool, ...] 
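+    # `overlapping_accesses`: paths of the accesses the fuzzer observed
+    # aliasing the staged window; the two guard tuples record, per dimension,
+    # whether any access violated the staged window's upper/lower bound and
+    # therefore needs a runtime guard.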
+
+
+@dataclass
+class CoverageSummary:
+    stage_mem_result: StageMemResult
+
+
+def fuzz(
+    scope1: Block,
+    scope2: Block,
+):
+    """
+    Fuzz to determine whether the behaviors of two programs are equivalent.
+    scope1: smallest scope containing all changes made by the scheduling op in the original program
+    scope2: the corresponding scope in the transformed program
+    """
+    cur_scope1 = TestScope(scope1)
+    cur_scope2 = TestScope(scope2)
+    cur_type_map1 = cur_scope1.get_type_map()
+    cur_type_map2 = cur_scope2.get_type_map()
+
+    while cur_scope1 is not None:
+        assert cur_scope2 is not None
+        cm = ConstraintMaker(cur_type_map1 | cur_type_map2)
+
+        spec1 = cur_scope1.get_test_spec(cm, cur_type_map1)
+        spec2 = cur_scope2.get_test_spec(cm, cur_type_map2)
+
+        transpiled_test1 = Transpiler(
+            # new proc that contains the current scope as a body, not the entire proc
+            spec1.proc,
+            CoverageArgs(
+                cm,
+                spec1.var_renaming,
+                spec1.forward_to_test(scope1),
+            ),
+        )
+        transpiled_test2 = Transpiler(
+            spec2.proc,
+            CoverageArgs(
+                cm,
+                spec2.var_renaming,
+                spec2.forward_to_test(scope2),
+            ),
+        )
+
+        config_fields = transpiled_test1.get_configs() | transpiled_test2.get_configs()
+
+        arg_types = spec1.arg_types | spec2.arg_types
+        # precondition of the current scope in both the original and transformed program
+        constraint = spec1.constraint.union(spec2.constraint)
+        skeleton1, skeleton2 = (
+            transpiled_test1.get_coverage_skeleton(),
+            transpiled_test2.get_coverage_skeleton(),
+        )
+        assert skeleton1 is not None and skeleton2 is not None
+        # symbolic representation of control flow in both the original and transformed scope
+        coverage_skeleton = skeleton1.merge(skeleton2)
+        tests_passed = True
+        skipped_tests = 0
+        iters = 0
+        while (
+            not coverage_skeleton.get_coverage_progress().is_finished()
+            and iters < MAX_ITERS
+            and tests_passed
+        ):
+            test_case = generate_test_case(
+                arg_types,
+                config_fields,
+                constraint,
+                coverage_skeleton,
+                cm,
+            )
+            # the constraint was unsolvable (or no solution was found within
+            # the search limit)
+            if test_case is None:
+                skipped_tests += 1
+                if skipped_tests > MAX_SKIPPED_TESTS:
+                    # giving up here would accept the rewrite without ever
+                    # testing it, which is probably unsound, so fail loudly
+                    assert False
+                else:
+                    continue
+
+            out1 = run_test_case(test_case, transpiled_test1)
+            out2 = run_test_case(test_case, transpiled_test2)
+            if out1 == "failed" or out2 == "failed":
+                # a precondition in a called subproc failed, or an access was
+                # out of bounds
+                tests_passed = False
+                break
+            assert out1.coverage_result is not None and out2.coverage_result is not None
+            coverage_skeleton.update_coverage(
+                out1.coverage_result | out2.coverage_result
+            )
+            for buffer_name in out1.buffer_values.keys() & out2.buffer_values.keys():
+                if not np.allclose(
+                    out1.buffer_values[buffer_name], out2.buffer_values[buffer_name]
+                ):
+                    tests_passed = False
+                    break
+            if cur_scope1.broaden() is not None:
+                for ctxt_name in out1.ctxt_object.keys() & out2.ctxt_object.keys():
+                    if not np.allclose(
+                        out1.ctxt_object[ctxt_name], out2.ctxt_object[ctxt_name]
+                    ):
+                        tests_passed = False
+                        break
+            iters += 1
+        if tests_passed:
+            return
+        else:
+            cur_scope1 = cur_scope1.broaden()
+            cur_scope2 = cur_scope2.broaden()
+    raise SchedulingError("tests failed at broadest scope")
+
+
+def fuzz_single_scope(
+    scope: Block, staging_args: Optional[StageMemArgs] = None
+) -> CoverageSummary:
+    """
+    Fuzz a single program to determine properties other than equivalence,
+    e.g. whether accesses can overlap for stage_mem, or parallel-loop checks.
+    scope: smallest scope containing the properties that need to be fuzzed
+    staging_args: arguments to the stage_mem scheduling op
+    """
+    cur_scope = TestScope(scope)
+    cur_type_map = cur_scope.get_type_map()
+
+    while cur_scope is not None:
+        cm = ConstraintMaker(cur_type_map)
+
+        spec = cur_scope.get_test_spec(cm, cur_type_map)
+
+        transpiled_test = Transpiler(
+            # new proc that contains the current scope as a body, not the entire proc
+            spec.proc,
+            CoverageArgs(
+                cm,
+                spec.var_renaming,
+                spec.forward_to_test(scope),
+                spec.forward_staging_args(staging_args),
+            ),
+        )
+
+        config_fields = transpiled_test.get_configs()
+
+        arg_types = spec.arg_types
+        # precondition of the current scope
+        constraint = spec.constraint
+
+        # symbolic representation of control flow in the current scope
+        coverage_skeleton = transpiled_test.get_coverage_skeleton()
+        assert coverage_skeleton is not None
+        tests_passed = True
+        skipped_tests = 0
+        iters = 0
+        while (
+            not coverage_skeleton.get_coverage_progress().is_finished()
+            and iters < MAX_ITERS
+            and tests_passed
+        ):
+            test_case = generate_test_case(
+                arg_types,
+                config_fields,
+                constraint,
+                coverage_skeleton,
+                cm,
+            )
+            # the constraint was unsolvable (or no solution was found within
+            # the search limit)
+            if test_case is None:
+                skipped_tests += 1
+                if skipped_tests > MAX_SKIPPED_TESTS:
+                    # giving up here would accept the rewrite without ever
+                    # testing it, which is probably unsound, so fail loudly
+                    assert False
+                else:
+                    continue
+
+            out = run_test_case(test_case, transpiled_test)
+            if out == "failed":
+                # a precondition in a called subproc failed, or an access was
+                # out of bounds
+                tests_passed = False
+                break
+            assert out.coverage_result is not None
+            coverage_skeleton.update_coverage(out.coverage_result)
+            iters += 1
+        if tests_passed:
+            overlapping_accesses = []
+            invalid_staging = False
+            for staging_overlap in coverage_skeleton.staging_overlaps:
+                if staging_overlap.has_overlap:
+                    if staging_overlap.has_disjoint_access:
+                        invalid_staging = True
+                        break
+                    else:
+                        overlapping_accesses.append(
+                            spec.backward_from_test(staging_overlap.access_cursor)
+                        )
+            if invalid_staging:
+                cur_scope = cur_scope.broaden()
+                continue
+            else:
+                return CoverageSummary(
+                    StageMemResult(
+                        frozenset(overlapping_accesses),
+                        tuple(
+                            staging_bound_check.violated_upper_bound
+                            for staging_bound_check in coverage_skeleton.staging_bound_checks
+                        ),
+                        tuple(
+                            staging_bound_check.violated_lower_bound
+                            for staging_bound_check in coverage_skeleton.staging_bound_checks
+                        ),
+                    )
+                )
+        else:
+            cur_scope = cur_scope.broaden()
+    raise SchedulingError("tests failed at broadest scope")
diff --git a/src/exo/rewrite/chexo/constraint_solver.py b/src/exo/rewrite/chexo/constraint_solver.py
new file mode 100644
index 000000000..971cdf0d7
--- /dev/null
+++ b/src/exo/rewrite/chexo/constraint_solver.py
@@ -0,0 +1,1013 @@
+from dataclasses import dataclass
+from typing import Callable, Literal, Union, Optional
+
+from ...core.configs import Config
+from ...core.prelude import Sym
+from ...core.LoopIR import LoopIR, T
+from ...core.extern import Extern
+import numpy as np
+from scipy.optimize import linprog
+from hsnf import smith_normal_form
+import textwrap
+
+
+@dataclass
+class IndexTerm:
+    buffer_sym: Sym
+    indices: tuple["Expression", ...]
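+    # `register_new_index` is invoked by `substitute` once every index
+    # collapses to a concrete int; it returns the fresh Sym standing for that
+    # buffer element (see ConstraintMaker.register_buffer_index)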
+ register_new_index: Callable[[tuple[int, ...]], Sym] + + def substitute(self, assignments: dict[Sym, int]) -> Union["IndexTerm", Sym]: + new_indices = tuple(index.substitute(assignments) for index in self.indices) + int_indices = [] + trivial = True + for new_index in new_indices: + trivial_val = new_index.get_trivial_result() + if trivial_val is None: + trivial = False + break + else: + int_indices.append(trivial_val) + if trivial: + return self.register_new_index(tuple(int_indices)) + return IndexTerm(self.buffer_sym, new_indices, self.register_new_index) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return frozenset().union(*(index.collect_syms() for index in self.indices)) + + def collect_syms(self) -> frozenset[Sym]: + return self.collect_nonlinear_syms() + + def pretty_print(self) -> str: + index_str = ",".join(index.pretty_print() for index in self.indices) + return f"{str(self.buffer_sym)}[{index_str}]" + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "IndexTerm": + return IndexTerm( + self.buffer_sym, + tuple(index.rename_syms(lookup) for index in self.indices), + self.register_new_index, + ) + + +@dataclass +class ExternTerm: + extern: Extern + args: tuple["Expression", ...] + + def substitute(self, assignments: dict[Sym, int]) -> Union["ExternTerm", int]: + new_args = tuple(arg.substitute(assignments) for arg in self.args) + int_args = [] + trivial = True + for new_arg in new_args: + trivial_val = new_arg.get_trivial_result() + if trivial_val is None: + trivial = False + break + else: + int_args.append(trivial_val) + if trivial: + return self.extern.interpret(tuple(int_args)) + return ExternTerm(self.extern, new_args) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return frozenset().union(*(arg.collect_syms() for arg in self.args)) + + def collect_syms(self) -> frozenset[Sym]: + return self.collect_nonlinear_syms() + + def pretty_print(self) -> str: + arg_str = ",".join(arg.pretty_print() for arg in self.args) + return f"{str(self.extern)}({arg_str})" + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ExternTerm": + return ExternTerm( + self.extern, + tuple(arg.rename_syms(lookup) for arg in self.args), + ) + + +FunctionTerm = Union[IndexTerm, ExternTerm] + + +@dataclass +class ConstraintTerm: + coefficient: int + syms: tuple[Sym, ...] + functions: tuple[FunctionTerm, ...] 
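+    # A term denotes the product coefficient * syms[0] * ... * functions[0] * ...;
+    # e.g. the polynomial term 3*x*y is ConstraintTerm(3, (x, y), ()).
+    # `multiply` accordingly multiplies the coefficients and concatenates the
+    # factor tuples, and a Constraint's Expression is a sum of such terms
+    # compared against 0.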
+ + def negate(self) -> "ConstraintTerm": + return ConstraintTerm(-self.coefficient, self.syms, self.functions) + + def multiply(self, other) -> "ConstraintTerm": + return ConstraintTerm( + self.coefficient * other.coefficient, + self.syms + other.syms, + self.functions + other.functions, + ) + + def substitute(self, assignments: dict[Sym, int]) -> "ConstraintTerm": + new_syms = [] + new_coefficient = self.coefficient + for sym in self.syms: + if sym in assignments: + new_coefficient *= assignments[sym] + else: + new_syms.append(sym) + new_functions = [] + for function in self.functions: + sub = function.substitute(assignments) + if isinstance(sub, int): + new_coefficient *= sub + elif isinstance(sub, Sym): + new_syms.append(sub) + else: + new_functions.append(sub) + return ConstraintTerm( + new_coefficient, + tuple(new_syms), + tuple(new_functions), + ) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + occurrences = set() + result = set() + for sym in self.syms: + if sym in occurrences: + result.add(sym) + else: + occurrences.add(sym) + return frozenset(result) + + def collect_syms(self) -> frozenset[Sym]: + return frozenset(self.syms).union( + *(function.collect_syms() for function in self.functions) + ) + + def pretty_print(self) -> str: + return f"{' * '.join([str(self.coefficient)] + [str(sym) for sym in self.syms] + [function.pretty_print() for function in self.functions])}" + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ConstraintTerm": + return ConstraintTerm( + self.coefficient, + tuple(lookup[sym] if sym in lookup else sym for sym in self.syms), + tuple(function.rename_syms(lookup) for function in self.functions), + ) + + +@dataclass +class LinearConstraint: + coefficients: dict[Sym, int] + offset: int + has_slack: bool + + def get_trivial_result(self) -> Optional[bool]: + if len(self.coefficients) > 0: + return None + return (self.offset >= 0 and self.has_slack) or self.offset == 0 + + +@dataclass +class Expression: + terms: Optional[tuple[ConstraintTerm, ...]] + + @staticmethod + def from_constant(const: int) -> "Expression": + return Expression((ConstraintTerm(const, (), ()),)) + + @staticmethod + def from_sym(sym: Sym) -> "Expression": + return Expression((ConstraintTerm(1, (sym,), ()),)) + + @staticmethod + def unsolvable() -> "Expression": + return Expression(None) + + @staticmethod + def from_function(function_term: FunctionTerm) -> "Expression": + return Expression((ConstraintTerm(1, (), (function_term,)),)) + + def negate(self) -> "Expression": + return Expression( + None if self.terms is None else tuple(term.negate() for term in self.terms) + ) + + def add(self, other: "Expression") -> "Expression": + return Expression( + None + if self.terms is None or other.terms is None + else (*self.terms, *other.terms) + ) + + def multiply(self, other: "Expression") -> "Expression": + return Expression( + None + if self.terms is None or other.terms is None + else tuple( + term1.multiply(term2) for term1 in self.terms for term2 in other.terms + ) + ) + + def substitute(self, assignments: dict[Sym, int]) -> "Expression": + if self.terms is None: + return self + coefficients: dict[tuple[Sym, ...], int] = {} + other_terms: list[ConstraintTerm] = [] + for term in self.terms: + sub_term = term.substitute(assignments) + if len(sub_term.functions) != 0: + other_terms.append(sub_term) + else: + if sub_term.syms not in coefficients: + coefficients[sub_term.syms] = 0 + coefficients[sub_term.syms] += sub_term.coefficient + return Expression( + tuple( + 
ConstraintTerm(coefficient, syms, ()) + for syms, coefficient in coefficients.items() + ) + + tuple(other_terms) + ) + + def get_trivial_result(self) -> Optional[int]: + if self.terms is None: + return None + if len(self.terms) == 0: + return 0 + elif len(self.terms) == 1 and len(self.terms[0].syms) == 0: + return self.terms[0].coefficient + return None + + def collect_syms(self) -> frozenset[Sym]: + if self.terms is None: + return frozenset() + return frozenset().union(*(term.collect_syms() for term in self.terms)) + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + if self.terms is None: + return frozenset() + return frozenset().union( + *[term.collect_nonlinear_syms() for term in self.terms] + ) + + def pretty_print(self): + if self.terms is None: + return "unsolvable" + return " + ".join([term.pretty_print() for term in self.terms]) + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "Expression": + if self.terms is None: + return self + return Expression(tuple(term.rename_syms(lookup) for term in self.terms)) + + +@dataclass +class Constraint: + lhs: Expression + has_slack: bool + + def linearize(self, assignments: dict[Sym, int]) -> Optional[LinearConstraint]: + new_lhs = self.lhs.substitute(assignments) + if new_lhs.terms is None: + return None + offset = 0 + coefficients = {} + for term in new_lhs.terms: + if len(term.functions) != 0: + return None + elif len(term.syms) == 0: + offset += term.coefficient + elif len(term.syms) == 1: + coefficients[term.syms[0]] = term.coefficient + else: + return None + return LinearConstraint(coefficients, offset, self.has_slack) + + def collect_syms(self) -> frozenset[Sym]: + return self.lhs.collect_syms() + + def collect_nonlinear_syms(self) -> frozenset[Sym]: + return self.lhs.collect_nonlinear_syms() + + def lift_to_disjoint_constraint(self) -> "DisjointConstraint": + return DisjointConstraint((ConstraintClause((self,)),)) + + def invert(self) -> "DisjointConstraint": + if self.has_slack: + return Constraint( + self.lhs.negate().add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + else: + return DisjointConstraint( + ( + ConstraintClause( + ( + Constraint( + self.lhs.add(Expression.from_constant(-1)), + True, + ), + ) + ), + ConstraintClause( + ( + Constraint( + self.lhs.negate().add(Expression.from_constant(-1)), + True, + ), + ) + ), + ) + ) + + def pretty_print(self) -> str: + return f"{self.lhs.pretty_print()} {'>=' if self.has_slack else '=='} 0" + + def substitute(self, assignments: dict[Sym, int]) -> "Constraint": + return Constraint(self.lhs.substitute(assignments), self.has_slack) + + def get_trivial_result(self) -> Optional[bool]: + lhs_result = self.lhs.get_trivial_result() + if lhs_result is not None: + return (lhs_result >= 0 and self.has_slack) or lhs_result == 0 + return None + + def is_unsolvable(self) -> bool: + return self.lhs.terms is None + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "Constraint": + return Constraint(self.lhs.rename_syms(lookup), self.has_slack) + + +@dataclass +class ConstraintClause: + constraints: tuple[Constraint, ...] 
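+    # A clause is the conjunction (AND) of its constraints, and
+    # DisjointConstraint below is a disjunction (OR) of clauses, i.e. the
+    # solver keeps constraints in disjunctive normal form. Inverting a clause
+    # ORs together the inversions of its members (De Morgan), which is why
+    # `invert` folds with `union` starting from FALSE_CONSTRAINT.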
+ + def invert(self) -> "DisjointConstraint": + acc = FALSE_CONSTRAINT + for constraint in self.constraints: + acc = acc.union(constraint.invert()) + return acc + + def pretty_print(self) -> str: + lines = [ + "intersect(", + *list( + textwrap.indent(constraint.pretty_print(), "\t") + "," + for constraint in self.constraints + ), + ")", + ] + return "\n".join(lines) + + def collect_syms(self) -> frozenset[Sym]: + return frozenset().union( + *(constraint.collect_syms() for constraint in self.constraints) + ) + + def substitute(self, assignments: dict[Sym, int]) -> "ConstraintClause": + new_constraints = [] + for constraint in self.constraints: + new_constraint = constraint.substitute(assignments) + trivial_result = new_constraint.get_trivial_result() + if trivial_result is None: + new_constraints.append(new_constraint) + elif not trivial_result: + return ConstraintClause((new_constraint,)) + return ConstraintClause(tuple(new_constraints)) + + def get_trivial_result(self) -> Optional[bool]: + if len(self.constraints) == 0: + return True + elif len(self.constraints) == 1: + return self.constraints[0].get_trivial_result() + return None + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "ConstraintClause": + return ConstraintClause( + tuple(constraint.rename_syms(lookup) for constraint in self.constraints) + ) + + +MAX_CLAUSES = 16 + + +@dataclass +class DisjointConstraint: + clauses: tuple[ConstraintClause, ...] + + def intersect(self, other: "DisjointConstraint"): + new_clauses = tuple( + ConstraintClause(lhs_clause.constraints + rhs_clause.constraints) + for lhs_clause in self.clauses + for rhs_clause in other.clauses + ) + if len(new_clauses) > MAX_CLAUSES: + new_clauses = tuple( + np.random.choice(new_clauses, MAX_CLAUSES, replace=False) + ) + return DisjointConstraint(new_clauses) + + def union(self, other: "DisjointConstraint"): + new_clauses = self.clauses + other.clauses + if len(new_clauses) > MAX_CLAUSES: + new_clauses = tuple( + np.random.choice(new_clauses, MAX_CLAUSES, replace=False) + ) + return DisjointConstraint(new_clauses) + + def invert(self) -> "DisjointConstraint": + acc = TRUE_CONSTRAINT + for clause in self.clauses: + acc = acc.intersect(clause.invert()) + return acc + + def pretty_print(self) -> str: + lines = [ + "union(", + *list( + textwrap.indent(clause.pretty_print(), "\t") + "," + for clause in self.clauses + ), + ")", + ] + return "\n".join(lines) + + def collect_syms(self) -> frozenset[Sym]: + return frozenset().union(*(clause.collect_syms() for clause in self.clauses)) + + def substitute(self, assignments: dict[Sym, int]) -> "DisjointConstraint": + new_clauses = [] + for clause in self.clauses: + new_clause = clause.substitute(assignments) + trivial_result = new_clause.get_trivial_result() + if trivial_result is None: + new_clauses.append(new_clause) + elif trivial_result: + return DisjointConstraint((new_clause,)) + return DisjointConstraint(tuple(new_clauses)) + + def get_trivial_result(self) -> Optional[bool]: + if len(self.clauses) == 0: + return False + elif len(self.clauses) == 1: + return self.clauses[0].get_trivial_result() + return None + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "DisjointConstraint": + return DisjointConstraint( + tuple(clause.rename_syms(lookup) for clause in self.clauses) + ) + + +TRUE_CONSTRAINT = DisjointConstraint((ConstraintClause(()),)) +FALSE_CONSTRAINT = DisjointConstraint(()) + + +@dataclass +class Solution: + ctxt: dict[tuple[Config, str], int] + var_assignments: dict[Sym, int] + buffer_assignments: 
dict[tuple[Sym, tuple[int, ...]], int] + substitutions: dict[Sym, int] + + +SIMUL_CONSTRAINT_LIMIT = 32 + + +class ConstraintMaker: + def __init__(self, type_map: dict[Sym, LoopIR.type]): + self.var_subs: dict[Sym, Expression] = {} + self.ctxt: dict[tuple[Config, str], Expression] = {} + self.extra_constraints: dict[Sym, DisjointConstraint] = {} + self.stride_dummies: dict[tuple[Sym, int], Sym] = {} + self.buffer_syms: dict[Sym, tuple[Sym, tuple[int, ...]]] = {} + self.unbound_buffers: set[Sym] = set() + self.type_map = type_map + for sym, sym_type in type_map.items(): + var_sub_result = self.make_var_sub(sym.name(), sym_type) + if var_sub_result is not None: + self.var_subs[sym] = var_sub_result + + def copy_var(self, sym: Sym) -> Sym: + new_sym = Sym(f"{sym.name()}_copy") + var_sub_result = self.make_var_sub(new_sym.name(), self.type_map[sym]) + if var_sub_result is not None: + self.var_subs[new_sym] = var_sub_result + return new_sym + + def top_var(self, sym: Sym) -> Sym: + new_sym = Sym(f"{sym.name()}_top") + sym_type = self.type_map[sym] + if sym_type.is_tensor_or_window(): + self.unbound_buffers.add(new_sym) + else: + self.var_subs[new_sym] = Expression.unsolvable() + return new_sym + + def get_var_sub(self, var_sym: Sym) -> Expression: + return self.var_subs[var_sym] + + def make_var_sub(self, name: str, var_type: LoopIR.type) -> Optional[Expression]: + if isinstance(var_type, (T.Size, T.Stride)): + # positive variable + return Expression.from_sym(Sym(f"{name}_m1")).add( + Expression.from_constant(1) + ) + elif isinstance(var_type, (T.Int, T.Index, T.INT32, T.INT8, T.UINT16, T.UINT8)): + # unsigned variables are represented as a - b, where a and b are nonnegative + a, b = Sym(f"{name}_a"), Sym(f"{name}_b") + return Expression.from_sym(a).add(Expression.from_sym(b).negate()) + elif isinstance(var_type, T.Bool): + # constrained to [0, 1] + sym = Sym(name) + self.extra_constraints[sym] = Constraint( + Expression.from_sym(sym).negate().add(Expression.from_constant(1)), + True, + ).lift_to_disjoint_constraint() + return Expression.from_sym(sym) + else: + return None + + def register_buffer_index(self, indices: tuple[int, ...], buffer_sym: Sym) -> Sym: + sym = Sym("buf") + self.buffer_syms[sym] = (buffer_sym, indices) + return sym + + def make_expression( + self, expr: Union[LoopIR.expr, Sym], var_renaming: dict[Sym, Sym] + ) -> Expression: + # expect that expr is int type + if isinstance(expr, Sym): + if expr in var_renaming: + return self.var_subs[var_renaming[expr]] + return self.var_subs[expr] + elif isinstance(expr, LoopIR.Read): + if len(expr.idx) == 0: + return self.var_subs[expr.name] + else: + buf_type = self.type_map[expr.name] + if isinstance(buf_type, LoopIR.Tensor): + buf_name = expr.name + index_exprs = tuple( + self.make_expression(idx, var_renaming) for idx in expr.idx + ) + elif isinstance(buf_type, LoopIR.WindowType): + buf_name = buf_type.src_buf + index_list: list[Expression] = [] + expr_idx_iter = iter(expr.idx) + for idx in buf_type.idx: + if isinstance(idx, LoopIR.Point): + index_list.append( + self.make_expression(idx.pt, var_renaming) + ) + elif isinstance(idx, LoopIR.Interval): + index_list.append( + self.make_expression(idx.lo, var_renaming).add( + self.make_expression( + next(expr_idx_iter), var_renaming + ) + ) + ) + else: + assert False, "unexpected window access" + index_exprs = tuple(index_list) + else: + assert False, "unexpected buffer type" + if buf_name in var_renaming: + buf_name = var_renaming[buf_name] + if buf_name in self.unbound_buffers: + return 
Expression.unsolvable() + return Expression.from_function( + IndexTerm( + buf_name, + index_exprs, + lambda indices: self.register_buffer_index(indices, buf_name), + ) + ) + elif isinstance(expr, LoopIR.Const): + return Expression.from_constant(expr.val) + elif isinstance(expr, LoopIR.USub): + return self.make_expression(expr.arg, var_renaming).negate() + elif isinstance(expr, LoopIR.BinOp): + # TODO: support mod and div using extra variables + lhs = self.make_expression(expr.lhs, var_renaming) + rhs = self.make_expression(expr.rhs, var_renaming) + if expr.op == "+": + return lhs.add(rhs) + elif expr.op == "-": + return lhs.add(rhs.negate()) + elif expr.op == "*": + return lhs.multiply(rhs) + elif expr.op in ["/", "%"]: + div, rem = Sym("div"), Sym("rem") + visible_sym = rem if expr.op == "%" else div + self.extra_constraints[visible_sym] = ( + Constraint( + lhs.negate() + .add(Expression.from_sym(rem)) + .add(rhs.multiply(Expression.from_sym(div))), + False, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + Expression.from_sym(rem) + .add(Expression.from_constant(1)) + .negate() + .add(rhs), + True, + ).lift_to_disjoint_constraint() + ) + ) + return Expression.from_sym(visible_sym) + else: + assert False, f"unsupported op in assertion: {expr.op}" + elif isinstance(expr, LoopIR.Extern): + extern_args = tuple( + self.make_expression(arg, var_renaming) for arg in expr.args + ) + extern: Extern = expr.f + extern_result_sym = Sym("ext") + try: + extern_constraint = extern.express_in_constraints( + extern_args, extern_result_sym + ) + self.extra_constraints[extern_result_sym] = extern_constraint + return Expression.from_sym(extern_result_sym) + except NotImplementedError: + return Expression.from_function(ExternTerm(extern, extern_args)) + elif isinstance(expr, LoopIR.StrideExpr): + if (expr.name, expr.dim) not in self.stride_dummies: + new_sym = Sym("stride") + self.stride_dummies[(expr.name, expr.dim)] = new_sym + dummy = self.stride_dummies[(expr.name, expr.dim)] + return Expression.from_sym(dummy) + elif isinstance(expr, LoopIR.ReadConfig): + if (expr.config, expr.field) not in self.ctxt: + field_type = expr.config.lookup_type(expr.field) + var_sub_result = self.make_var_sub( + f"{expr.config.name()}_{expr.field}", field_type + ) + assert ( + var_sub_result is not None + ), "constraints can only occur on control variables" + self.ctxt[(expr.config, expr.field)] = var_sub_result + return self.ctxt[(expr.config, expr.field)] + else: + assert False, f"unsupported expr" + + def make_constraint( + self, expr: LoopIR.expr, var_renaming: dict[Sym, Sym] + ) -> DisjointConstraint: + # expect that expr is bool type + if isinstance(expr, LoopIR.BinOp): + if expr.op == "and": + lhs_constraints, rhs_constraints = self.make_constraint( + expr.lhs, var_renaming + ), self.make_constraint(expr.rhs, var_renaming) + return lhs_constraints.intersect(rhs_constraints) + elif expr.op == "or": + lhs_constraints, rhs_constraints = self.make_constraint( + expr.lhs, var_renaming + ), self.make_constraint(expr.rhs, var_renaming) + return lhs_constraints.union(rhs_constraints) + elif expr.op == "==" and isinstance(expr.lhs.type, LoopIR.Bool): + lhs_constraints, rhs_constraints = self.make_constraint( + expr.lhs, var_renaming + ), self.make_constraint(expr.rhs, var_renaming) + return ( + lhs_constraints.invert().intersect(rhs_constraints.invert()) + ).union(lhs_constraints.intersect(rhs_constraints)) + else: + return self.make_constraint_from_inequality( + expr.lhs, expr.rhs, expr.op, var_renaming + 
).lift_to_disjoint_constraint() + elif isinstance(expr, LoopIR.Read): + assert len(expr.idx) == 0, "cannot index into boolean" + return Constraint( + Expression.from_sym(expr.name).add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + elif isinstance(expr, LoopIR.Const): + return TRUE_CONSTRAINT if expr.val else FALSE_CONSTRAINT + elif isinstance(expr, LoopIR.ReadConfig): + if (expr.config, expr.field) not in self.ctxt: + field_type = expr.config.lookup_type(expr.field) + assert isinstance(field_type, LoopIR.Bool) + var_sub_result = self.make_var_sub( + f"{expr.config.name()}_{expr.field}", field_type + ) + assert ( + var_sub_result is not None + ), "constraints can only occur on control variables" + self.ctxt[(expr.config, expr.field)] = var_sub_result + return Constraint( + self.ctxt[(expr.config, expr.field)].add(Expression.from_constant(-1)), + False, + ).lift_to_disjoint_constraint() + else: + assert False, "only boolean expected" + + def make_constraint_from_inequality( + self, + lhs: Union[LoopIR.expr, Sym], + rhs: Union[LoopIR.expr, Sym], + op: str, + var_renaming: dict[Sym, Sym], + ) -> Constraint: + lhs_expr = self.make_expression(lhs, var_renaming) + rhs_expr = self.make_expression(rhs, var_renaming) + if op == "<": + return Constraint( + rhs_expr.add(lhs_expr.negate()).add(Expression.from_constant(-1)), True + ) + elif op == ">": + return Constraint( + lhs_expr.add(rhs_expr.negate()).add(Expression.from_constant(-1)), True + ) + elif op == "<=": + return Constraint(rhs_expr.add(lhs_expr.negate()), True) + elif op == ">=": + return Constraint(lhs_expr.add(rhs_expr.negate()), True) + elif op == "==": + return Constraint(lhs_expr.add(rhs_expr.negate()), False) + else: + assert False, "boolean ops expected" + + def _make_solution_from_assignments(self, assignments: dict[Sym, int]) -> Solution: + var_assignments = {} + for sym, sub in self.var_subs.items(): + result = sub.substitute(assignments).get_trivial_result() + if result is not None: + var_assignments[sym] = result + buffer_assignments = {} + for sym, assignment in assignments.items(): + if sym in self.buffer_syms: + buffer_assignments[self.buffer_syms[sym]] = assignment + ctxt = {} + for (config, field), sub in self.ctxt.items(): + result = sub.substitute(assignments).get_trivial_result() + if result is not None: + ctxt[(config, field)] = result + return Solution(ctxt, var_assignments, buffer_assignments, assignments) + + def _solve_for_assignments( + self, all_constraints: tuple[Constraint, ...], bound: int + ) -> Union[Literal["failed", "infeasible"], dict[Sym, int]]: + assignments = {} + self.buffer_syms = {} + while True: + linear_constraints: list[LinearConstraint] = [] + linear_constraint_syms: set[Sym] = set() + nonlinear_syms: set[Sym] = set() + nontrivial_constraint_exists = False + for constraint in all_constraints: + if constraint.is_unsolvable(): + return "infeasible" + linear_result = constraint.linearize(assignments) + if linear_result is not None: + trivial_result = linear_result.get_trivial_result() + if trivial_result == False: + return "infeasible" if len(assignments) == 0 else "failed" + elif trivial_result is None: + nontrivial_constraint_exists = True + linear_constraints.append(linear_result) + linear_constraint_syms |= { + sym for sym in linear_result.coefficients.keys() + } + else: + nontrivial_constraint_exists = True + nonlinear_syms |= constraint.substitute( + assignments + ).collect_nonlinear_syms() + if not nontrivial_constraint_exists: + break + priority_syms = 
nonlinear_syms & linear_constraint_syms + if len(priority_syms) == 0 and len(nonlinear_syms) != 0: + chosen_sym = np.random.choice( + sorted(list(nonlinear_syms), key=lambda sym: sym._id) + ) + assignments[chosen_sym] = np.random.randint(0, bound) + continue + sym_ordering = { + sym: i + for i, sym in enumerate( + sorted( + list(linear_constraint_syms), + key=lambda sym: sym._id, + ) + ) + } + n = len(linear_constraints) + m_nonslack = len(linear_constraint_syms) + matrix_A = np.zeros( + (n, m_nonslack), + dtype=np.int32, + ) + m = m_nonslack + vec_b = np.zeros(n, dtype=np.int32) + for row, linear_constraint in enumerate(linear_constraints): + for sym, coefficient in linear_constraint.coefficients.items(): + matrix_A[row, sym_ordering[sym]] = coefficient + if linear_constraint.has_slack: + slack_col = np.zeros((n, 1), dtype=np.int32) + slack_col[row, 0] = -1 + matrix_A = np.hstack((matrix_A, slack_col)) + m += 1 + vec_b[row] = -linear_constraint.offset + matrix_B, matrix_U, matrix_V = smith_normal_form(matrix_A) + vec_d = matrix_U @ vec_b + k = min(n, m) + vec_f = np.zeros(m) + for i in range(min(n, m)): + if matrix_B[i, i] == 0: + k = i + break + if vec_d[i] % matrix_B[i, i] != 0: + return "infeasible" if len(assignments) == 0 else "failed" + vec_f += vec_d[i] / matrix_B[i, i] * matrix_V[:, i] + if m == k: + solution = vec_f + if not np.all(vec_f >= 0): + return "infeasible" if len(assignments) == 0 else "failed" + else: + matrix_C = matrix_V[:, k:] + upper_bound_matrix = np.concatenate( + (matrix_C[:m_nonslack, :], -matrix_C), axis=0 + ) + upper_bound_offset = np.concatenate( + (np.ones(m_nonslack) * bound - vec_f[:m_nonslack], vec_f), + axis=0, + ) + radius_row = np.zeros((1, m - k + 1)) + radius_row[0, -1] = -1 + upper_bound_matrix_with_radius = np.concatenate( + ( + np.concatenate( + ( + upper_bound_matrix, + np.linalg.norm(upper_bound_matrix, axis=1)[ + :, np.newaxis + ], + ), + axis=1, + ), + radius_row, + ), + axis=0, + ) + upper_bound_offset_with_radius = np.concatenate( + (upper_bound_offset, np.array([0])), axis=0 + ) + objective = np.zeros(m - k + 1) + objective[-1] = -1 + lp = linprog( + objective, + A_ub=upper_bound_matrix_with_radius, + b_ub=upper_bound_offset_with_radius, + bounds=(None, None), + method="highs", + ) + if not lp.success: + return "infeasible" if len(assignments) == 0 else "failed" + cur_y = lp.x[: m - k] + har_iter = 50 + last_int_y = None + for _ in range(har_iter): + direction = np.random.normal(size=m - k) + direction = direction / np.linalg.norm(direction) + lower_bounds = -matrix_C @ cur_y - vec_f + upper_bounds = lower_bounds + bound + upper_bounds[m_nonslack:] = -np.nan + coefficients = matrix_C @ direction + lower_bounds = lower_bounds[coefficients != 0] + upper_bounds = upper_bounds[coefficients != 0] + coefficients = coefficients[coefficients != 0] + max_lambda = np.nanmin( + np.where(coefficients < 0, lower_bounds, upper_bounds) + / coefficients + ) + min_lambda = np.nanmax( + np.where(coefficients >= 0, lower_bounds, upper_bounds) + / coefficients + ) + new_y = cur_y + direction * ( + np.random.rand() * (max_lambda - min_lambda) + min_lambda + ) + new_int_y = np.round(new_y) + cur_y = new_y + if np.all(upper_bound_matrix @ new_int_y <= upper_bound_offset): + last_int_y = new_int_y + if last_int_y is not None: + solution = matrix_C @ last_int_y + vec_f + else: + return "infeasible" if len(assignments) == 0 else "failed" + + if len(nonlinear_syms) == 0: + for sym in linear_constraint_syms: + assignments[sym] = int(solution[sym_ordering[sym]]) + 
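+                # every remaining constraint was linear, so the integer point
+                # found above assigns all outstanding symbols at once;
+                # otherwise (below) fix a single symbol and re-linearize on
+                # the next pass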
else: + chosen_sym = None + if len(priority_syms) != 0: + chosen_sym = np.random.choice( + sorted(list(priority_syms), key=lambda sym: sym._id) + ) + else: + assert len(linear_constraint_syms) != 0 + chosen_sym = np.random.choice( + sorted(list(linear_constraint_syms), key=lambda sym: sym._id) + ) + assignments[chosen_sym] = int(solution[sym_ordering[chosen_sym]]) + return assignments + + def solve_constraint( + self, + disjoint_constraint: DisjointConstraint, + *, + partial_solution: Optional[Solution] = None, + bound: int, + search_limit: int, + seed: Optional[int] = None, + ) -> Optional[Solution]: + if seed is not None: + np.random.seed(seed=seed) + if partial_solution is not None: + disjoint_constraint = disjoint_constraint.substitute( + partial_solution.substitutions + ) + + clauses = list( + clause + for clause in disjoint_constraint.clauses + if all(not constraint.is_unsolvable() for constraint in clause.constraints) + and len(clause.constraints) <= SIMUL_CONSTRAINT_LIMIT + ) + for _ in range(search_limit): + if len(clauses) == 0: + return None + chosen_clause = np.random.choice(clauses) + assert isinstance(chosen_clause, ConstraintClause) + chosen_clause_syms = chosen_clause.collect_syms() + chosen_extra_clauses: list[Constraint] = [] + failed_to_choose = False + for sym, extra_constraint in self.extra_constraints.items(): + if sym in chosen_clause_syms: + extra_constraint_clauses = list( + clause + for clause in extra_constraint.clauses + if all( + not constraint.is_unsolvable() + for constraint in clause.constraints + ) + ) + if len(extra_constraint_clauses) == 0: + failed_to_choose = True + break + chosen_extra_clause = np.random.choice(extra_constraint_clauses) + assert isinstance(chosen_extra_clause, ConstraintClause) + chosen_extra_clauses.extend(chosen_extra_clause.constraints) + if failed_to_choose: + continue + all_constraints = chosen_clause.constraints + tuple(chosen_extra_clauses) + assignment_result = self._solve_for_assignments(all_constraints, bound) + if assignment_result == "failed": + continue + elif assignment_result == "infeasible": + clauses = list(clause for clause in clauses if clause != chosen_clause) + else: + return self._make_solution_from_assignments( + ({} if partial_solution is None else partial_solution.substitutions) + | assignment_result + ) + return None + + def rename_sym_set( + self, syms: frozenset[Sym], free_vars: frozenset[Sym] + ) -> tuple[dict[Sym, Sym], dict[Sym, Sym]]: + var_renaming = {} + sym_renaming = {} + for var in free_vars: + var_sub = self.var_subs[var] + var_sub_syms = var_sub.collect_syms() + if len(var_sub_syms & syms) != 0: + sym_renaming |= {sym: Sym(sym.name()) for sym in var_sub_syms} + renamed_var = Sym(var.name()) + var_renaming[var] = renamed_var + self.var_subs[renamed_var] = var_sub.rename_syms(sym_renaming) + new_extra_constraints = {} + for sym, extra_constraint in self.extra_constraints.items(): + if ( + len(extra_constraint.collect_syms() & sym_renaming.keys()) != 0 + and sym in syms + ): + new_sym = Sym(sym.name()) + new_extra_constraints[new_sym] = extra_constraint.rename_syms( + sym_renaming + ) + sym_renaming[sym] = new_sym + self.extra_constraints |= new_extra_constraints + return ( + sym_renaming, + var_renaming, + ) diff --git a/src/exo/rewrite/chexo/coverage.py b/src/exo/rewrite/chexo/coverage.py new file mode 100644 index 000000000..aba759564 --- /dev/null +++ b/src/exo/rewrite/chexo/coverage.py @@ -0,0 +1,864 @@ +from dataclasses import dataclass, field +from typing import Generator, Optional, Union 
+import numpy as np + +from .constraint_solver import ( + Constraint, + ConstraintMaker, + DisjointConstraint, + TRUE_CONSTRAINT, + Expression, + Solution, +) +from ...core.prelude import Sym +from ...core.internal_cursors import NodePath +from ..new_eff import SchedulingError + + +@dataclass +class CoverageProgress: + covered_cases: int + total_cases: int + + def merge(self, other: "CoverageProgress") -> "CoverageProgress": + return CoverageProgress( + self.covered_cases + other.covered_cases, + self.total_cases + other.total_cases, + ) + + def is_finished(self) -> bool: + return self.covered_cases == self.total_cases + + +@dataclass +class CoverageSolverState: + current_constraint: DisjointConstraint + is_base_constraint: bool + current_solution: Solution + cm: ConstraintMaker + free_vars: frozenset[Sym] + bound: int + search_limit: int + + def update_solution( + self, new_constraint: DisjointConstraint, new_solution: Solution + ): + return CoverageSolverState( + new_constraint, + False, + new_solution, + self.cm, + self.free_vars, + self.bound, + self.search_limit, + ) + + +@dataclass(order=True, frozen=True) +class IndexedFiller: + index: int + placefiller: str + + +@dataclass +class CoverageSkeletonBranch: + true_child: "CoverageSkeletonNode" + false_child: "CoverageSkeletonNode" + + def write_coverage_declarations(self, decls: dict[Sym, str]): + self.true_child.write_coverage_declarations(decls) + self.false_child.write_coverage_declarations(decls) + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + yield from self.true_child.get_indexed_fillers() + yield from self.false_child.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return ( + self.true_child.get_coverage_syms() | self.false_child.get_coverage_syms() + ) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + self.true_child.update_coverage(coverage_result) + self.false_child.update_coverage(coverage_result) + + def get_coverage_progress(self) -> CoverageProgress: + return self.true_child.get_coverage_progress().merge( + self.false_child.get_coverage_progress() + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + uncovered_path = None + if self.true_child.visited and not self.false_child.visited: + uncovered_path = False + elif self.false_child.visited and not self.true_child.visited: + uncovered_path = True + + if uncovered_path is not None: + path_constraint = ( + self.true_child.get_complete_constraint() + if uncovered_path + else self.false_child.get_complete_constraint() + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if new_solution is None and state.is_base_constraint: + (self.true_child if uncovered_path else self.false_child).mark_visited() + if new_solution is not None: + if uncovered_path: + self.true_child.visited = True + else: + self.false_child.visited = True + return state.update_solution(new_constraint, new_solution) + elif self.true_child.visited and self.false_child.visited: + return self.false_child.solve_coverage( + self.true_child.solve_coverage(state) + ) + return state + + +@dataclass +class CoverageSkeletonNode: + coverage_sym: Optional[Sym] + parent_edge: Optional[tuple["CoverageSkeletonNode", 
DisjointConstraint]] + indexed_fillers: tuple[IndexedFiller, ...] + branches: list[CoverageSkeletonBranch] = field(default_factory=lambda: []) + visited: bool = False # mutable + + def write_coverage_declarations(self, decls: dict[Sym, str]): + if self.coverage_sym is not None: + decls[self.coverage_sym] = "false" + for branch in self.branches: + branch.write_coverage_declarations(decls) + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + for branch in self.branches: + yield from branch.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset( + (self.coverage_sym,) if self.coverage_sym is not None else () + ).union(*tuple(branch.get_coverage_syms() for branch in self.branches)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + if self.coverage_sym is None: + self.visited = True + else: + covered = coverage_result[repr(self.coverage_sym)] + assert isinstance(covered, bool) + self.visited |= covered + for branch in self.branches: + branch.update_coverage(coverage_result) + + def get_complete_constraint(self) -> DisjointConstraint: + current_edge = self.parent_edge + result = TRUE_CONSTRAINT + while current_edge is not None: + result = result.intersect(current_edge[1]) + current_edge = current_edge[0].parent_edge + return result + + def get_coverage_progress(self) -> CoverageProgress: + result = CoverageProgress(1 if self.visited else 0, 1) + for branch in self.branches: + result = result.merge(branch.get_coverage_progress()) + return result + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + current_state = state + for branch in self.branches: + current_state = branch.solve_coverage(current_state) + return current_state + + def mark_visited(self): + self.visited = True + for branch in self.branches: + branch.true_child.mark_visited() + branch.false_child.mark_visited() + + +@dataclass +class MemoryAccess: + coverage_sym: Sym + node: CoverageSkeletonNode + index: tuple[Expression, ...] + indexed_fillers: tuple[IndexedFiller, ...] 
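+    # `node` situates this access in the skeleton (its path constraint is the
+    # conjunction of branch conditions leading to it); `index` holds the
+    # symbolic index expressions that are compared pairwise to force or rule
+    # out aliasing with another access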
+ + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def make_renamed_constraint_and_indices( + self, state: CoverageSolverState + ) -> tuple[DisjointConstraint, tuple[Expression, ...], dict[Sym, Sym]]: + path_constraint = self.node.get_complete_constraint() + sym_renaming, var_renaming = state.cm.rename_sym_set( + path_constraint.collect_syms().union( + *(index_expr.collect_syms() for index_expr in self.index) + ), + state.free_vars, + ) + return ( + path_constraint.rename_syms(sym_renaming), + tuple(index_expr.rename_syms(sym_renaming) for index_expr in self.index), + var_renaming, + ) + + +@dataclass +class MemoryAccessPair: + access1: MemoryAccess + access2: MemoryAccess + visited_aliasing: bool = False # mutable + visited_nonaliasing: bool = False # mutable + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + yield from self.access1.get_indexed_fillers() + yield from self.access2.get_indexed_fillers() + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.access1.coverage_sym, self.access2.coverage_sym)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + access1_view = coverage_result[repr(self.access1.coverage_sym)] + access2_view = coverage_result[repr(self.access2.coverage_sym)] + if isinstance(access1_view, memoryview): + assert isinstance(access2_view, memoryview) + access1_arr = np.asarray(access1_view) + access2_arr = np.asarray(access2_view) + self.visited_aliasing |= np.any(access1_arr & access2_arr) + self.visited_nonaliasing |= np.any(access1_arr & ~access2_arr) and np.any( + ~access1_arr & access2_arr + ) + else: + assert isinstance(access2_view, bool) + aliased = access1_view and access2_view + self.visited_aliasing |= aliased # nonaliasing not possible without tensor + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + (1 if self.visited_aliasing else 0) + + (1 if self.visited_nonaliasing else 0), + 2, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + uncovered_path = None + if self.visited_aliasing and not self.visited_nonaliasing: + uncovered_path = False + elif self.visited_nonaliasing and not self.visited_aliasing: + uncovered_path = True + + if uncovered_path is not None: + ( + access1_path_constraint, + access1_indices, + _, + ) = self.access1.make_renamed_constraint_and_indices(state) + ( + access2_path_constraint, + access2_indices, + _, + ) = self.access2.make_renamed_constraint_and_indices(state) + path_constraints = access1_path_constraint.intersect( + access2_path_constraint + ) + alias_constraint = TRUE_CONSTRAINT + for index1, index2 in zip(access1_indices, access2_indices): + alias_constraint = alias_constraint.intersect( + Constraint( + index1.negate().add(index2), + False, + ).lift_to_disjoint_constraint() + ) + if not uncovered_path: + alias_constraint = alias_constraint.invert() + new_constraint = state.current_constraint.intersect( + path_constraints + ).intersect(alias_constraint) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + if uncovered_path: + self.visited_aliasing = True + else: + self.visited_nonaliasing = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + elif not self.visited_nonaliasing and 
not self.visited_aliasing: + ( + access1_path_constraint, + access1_indices, + _, + ) = self.access1.make_renamed_constraint_and_indices(state) + ( + access2_path_constraint, + access2_indices, + _, + ) = self.access2.make_renamed_constraint_and_indices(state) + path_constraints = access1_path_constraint.intersect( + access2_path_constraint + ) + new_constraint = state.current_constraint.intersect(path_constraints) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if new_solution is None: + self.visited_aliasing = True + self.visited_nonaliasing = True + else: + return state.update_solution(new_constraint, new_solution) + + return state + + +@dataclass +class FailureCondition: + coverage_sym: Sym + fail_cond: DisjointConstraint + node: CoverageSkeletonNode + indexed_fillers: tuple[IndexedFiller, ...] + visited_failure: bool = False # mutable + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.coverage_sym,)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + failed = coverage_result[repr(self.coverage_sym)] + assert isinstance(failed, bool) + self.visited_failure |= failed + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + 1 if self.visited_failure else 0, + 1, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + if not self.visited_failure: + path_constraint = self.node.get_complete_constraint().intersect( + self.fail_cond + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + self.visited_failure = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + +@dataclass +class SymbolicPoint: + index: Expression + + def collect_syms(self) -> frozenset[Sym]: + return self.index.collect_syms() + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "SymbolicPoint": + return SymbolicPoint(self.index.rename_syms(lookup)) + + +@dataclass +class SymbolicSlice: + lower_bound: Expression + upper_bound: Expression + + def collect_syms(self) -> frozenset[Sym]: + return self.lower_bound.collect_syms() | self.upper_bound.collect_syms() + + def rename_syms(self, lookup: dict[Sym, Sym]) -> "SymbolicSlice": + return SymbolicSlice( + self.lower_bound.rename_syms(lookup), self.upper_bound.rename_syms(lookup) + ) + + +SymbolicWindowIndex = Union[SymbolicPoint, SymbolicSlice] + + +@dataclass +class StagingBoundCheck: + upper_bound_violation_sym: Sym + lower_bound_violation_sym: Sym + staged_index: SymbolicWindowIndex + parent_index: SymbolicSlice + node: CoverageSkeletonNode + indexed_fillers: tuple[IndexedFiller, ...] 
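+    # mutable coverage state: `violated_*` record whether a run actually went
+    # out of bounds, `visited_*` whether the solver has already targeted (or
+    # proven unreachable) that violation case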
+ violated_upper_bound: bool = False + violated_lower_bound: bool = False + visited_upper_bound_violation: bool = False + visited_lower_bound_violation: bool = False + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset( + (self.upper_bound_violation_sym, self.lower_bound_violation_sym) + ) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + violated_upper_bound = coverage_result[repr(self.upper_bound_violation_sym)] + violated_lower_bound = coverage_result[repr(self.lower_bound_violation_sym)] + assert isinstance(violated_upper_bound, bool) and isinstance( + violated_lower_bound, bool + ) + self.violated_upper_bound |= violated_upper_bound + self.violated_lower_bound |= violated_lower_bound + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + (1 if self.visited_upper_bound_violation else 0) + + (1 if self.visited_lower_bound_violation else 0), + 2, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + if not self.visited_upper_bound_violation: + out_of_bounds_cond = ( + Constraint( + self.staged_index.index.add(Expression.from_constant(1)) + .negate() + .add(self.parent_index.upper_bound), + True, + ).lift_to_disjoint_constraint() + if isinstance(self.staged_index, SymbolicPoint) + else Constraint( + self.staged_index.upper_bound.negate().add( + self.parent_index.upper_bound + ), + True, + ).lift_to_disjoint_constraint() + ) + path_constraint = self.node.get_complete_constraint().intersect( + out_of_bounds_cond + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + self.visited_upper_bound_violation = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + if not self.visited_lower_bound_violation: + out_of_bounds_cond = ( + Constraint( + self.staged_index.index.add(self.parent_index.lower_bound.negate()), + True, + ).lift_to_disjoint_constraint() + if isinstance(self.staged_index, SymbolicPoint) + else Constraint( + self.staged_index.lower_bound.add( + self.parent_index.lower_bound.negate() + ), + True, + ).lift_to_disjoint_constraint() + ) + path_constraint = self.node.get_complete_constraint().intersect( + out_of_bounds_cond + ) + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms(), + state.free_vars, + ) + new_constraint = state.current_constraint.intersect( + path_constraint.rename_syms(sym_renaming) + ) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + self.visited_lower_bound_violation = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + +# for stage_mem +@dataclass +class StagingOverlap: + overlap_sym: Sym + disjoint_sym: Sym + staged_window: tuple[SymbolicWindowIndex, ...] + access_window: tuple[SymbolicWindowIndex, ...] 
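+    # per-dimension point/slice indices of the staged window and of the
+    # access being compared; the windows overlap only if every dimension's
+    # intervals intersect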
+ node: CoverageSkeletonNode + access_cursor: NodePath + indexed_fillers: tuple[IndexedFiller, ...] + has_overlap: bool = False + has_disjoint_access: bool = False + visited_overlap: bool = False + visited_disjoint: bool = False + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.overlap_sym, self.disjoint_sym)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + overlap = coverage_result[repr(self.overlap_sym)] + disjoint = coverage_result[repr(self.disjoint_sym)] + assert isinstance(overlap, bool) and isinstance(disjoint, bool) + self.has_overlap |= overlap + self.visited_overlap |= overlap + self.has_disjoint_access |= disjoint + self.visited_disjoint |= disjoint + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + (1 if self.visited_overlap else 0) + (1 if self.visited_disjoint else 0), + 2, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + uncovered_overlap = None + if self.visited_overlap and not self.visited_disjoint: + uncovered_overlap = False + elif self.visited_disjoint and not self.visited_overlap: + uncovered_overlap = True + + if uncovered_overlap is not None: + path_constraint = self.node.get_complete_constraint() + sym_renaming, _ = state.cm.rename_sym_set( + path_constraint.collect_syms().union( + *( + (access_window_idx.collect_syms()) + for access_window_idx in self.access_window + ), + *( + staged_window_idx.collect_syms() + for staged_window_idx in self.staged_window + ), + ), + state.free_vars, + ) + path_constraint = path_constraint.rename_syms(sym_renaming) + overlap_constraint = TRUE_CONSTRAINT + for access_idx, staged_idx in zip( + map(lambda idx: idx.rename_syms(sym_renaming), self.access_window), + map(lambda idx: idx.rename_syms(sym_renaming), self.staged_window), + ): + if isinstance(staged_idx, SymbolicPoint) and isinstance( + access_idx, SymbolicPoint + ): + index_overlap_constraint = Constraint( + access_idx.index.negate().add(staged_idx.index), + False, + ).lift_to_disjoint_constraint() + elif isinstance(staged_idx, SymbolicSlice) and isinstance( + access_idx, SymbolicPoint + ): + index_overlap_constraint = ( + Constraint( + staged_idx.lower_bound.negate().add(access_idx.index), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + access_idx.index.negate() + .add(staged_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + ) + ) + elif isinstance(staged_idx, SymbolicPoint) and isinstance( + access_idx, SymbolicSlice + ): + index_overlap_constraint = ( + Constraint( + access_idx.lower_bound.negate().add(staged_idx.index), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + staged_idx.index.negate() + .add(access_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + ) + ) + elif isinstance(staged_idx, SymbolicSlice) and isinstance( + access_idx, SymbolicSlice + ): + index_overlap_constraint = ( + Constraint( + staged_idx.lower_bound.negate() + .add(access_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ) + .lift_to_disjoint_constraint() + .intersect( + Constraint( + access_idx.lower_bound.negate() + .add(staged_idx.upper_bound) + .add(Expression.from_constant(-1)), + True, + ).lift_to_disjoint_constraint() + ) + ) + else: + assert False + 
overlap_constraint = overlap_constraint.intersect( + index_overlap_constraint + ) + if not uncovered_overlap: + overlap_constraint = overlap_constraint.invert() + new_constraint = state.current_constraint.intersect( + path_constraint + ).intersect(overlap_constraint) + new_solution = state.cm.solve_constraint( + new_constraint, bound=state.bound, search_limit=state.search_limit + ) + if ( + new_solution is None and state.is_base_constraint + ) or new_solution is not None: + if uncovered_overlap: + self.visited_overlap = True + else: + self.visited_disjoint = True + if new_solution is not None: + return state.update_solution(new_constraint, new_solution) + return state + + +@dataclass +class ParallelAccess: + node: CoverageSkeletonNode + index: tuple[Expression, ...] + indexed_fillers: tuple[IndexedFiller, ...] + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def make_renamed_constraint_and_indices( + self, state: CoverageSolverState + ) -> tuple[DisjointConstraint, tuple[Expression, ...], dict[Sym, Sym]]: + path_constraint = self.node.get_complete_constraint() + sym_renaming, var_renaming = state.cm.rename_sym_set( + path_constraint.collect_syms().union( + *(index_expr.collect_syms() for index_expr in self.index) + ), + state.free_vars, + ) + return ( + path_constraint.rename_syms(sym_renaming), + tuple(index_expr.rename_syms(sym_renaming) for index_expr in self.index), + var_renaming, + ) + + +@dataclass +class ParallelAccessPair: + coverage_sym: Sym + iter_sym: Sym + access1: ParallelAccess + access2: ParallelAccess + indexed_fillers: tuple[IndexedFiller, ...] + has_aliased: bool = False + visited_aliasing: bool = False # mutable + + def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]: + yield from self.access1.get_indexed_fillers() + yield from self.access2.get_indexed_fillers() + for indexed_filler in self.indexed_fillers: + yield indexed_filler + + def get_coverage_syms(self) -> frozenset[Sym]: + return frozenset((self.coverage_sym,)) + + def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]): + aliased = coverage_result[repr(self.coverage_sym)] + assert isinstance(aliased, bool) + self.has_aliased |= aliased + self.visited_aliasing |= aliased + + def get_coverage_progress(self) -> CoverageProgress: + return CoverageProgress( + 1 if self.visited_aliasing else 0, + 1, + ) + + def solve_coverage(self, state: CoverageSolverState) -> CoverageSolverState: + if not self.visited_aliasing: + ( + access1_path_constraint, + access1_indices, + access1_var_renaming, + ) = self.access1.make_renamed_constraint_and_indices(state) + ( + access2_path_constraint, + access2_indices, + access2_var_renaming, + ) = self.access2.make_renamed_constraint_and_indices(state) + path_constraints = access1_path_constraint.intersect( + access2_path_constraint + ) + alias_constraint = TRUE_CONSTRAINT + for index1, index2 in zip(access1_indices, access2_indices): + alias_constraint = alias_constraint.intersect( + Constraint( + index1.negate().add(index2), + False, + ).lift_to_disjoint_constraint() + ) + + different_iteration_constraint = TRUE_CONSTRAINT + if ( + self.iter_sym in access1_var_renaming + and self.iter_sym in access2_var_renaming + ): + different_iteration_constraint = Constraint( + Expression.from_sym(access1_var_renaming[self.iter_sym]) + .negate() + .add(Expression.from_sym(access2_var_renaming[self.iter_sym])), + False, + ).invert() + + new_constraint = ( + 
state.current_constraint.intersect(path_constraints)
+                .intersect(alias_constraint)
+                .intersect(different_iteration_constraint)
+            )
+            new_solution = state.cm.solve_constraint(
+                new_constraint, bound=state.bound, search_limit=state.search_limit
+            )
+            if (
+                new_solution is None and state.is_base_constraint
+            ) or new_solution is not None:
+                self.visited_aliasing = True
+            if new_solution is not None:
+                return state.update_solution(new_constraint, new_solution)
+        return state
+
+
+CoverageTask = Union[
+    CoverageSkeletonNode,
+    MemoryAccessPair,
+    FailureCondition,
+    StagingOverlap,
+    StagingBoundCheck,
+    ParallelAccessPair,
+]
+
+
+@dataclass
+class CoverageSkeleton:
+    roots: tuple[CoverageSkeletonNode, ...]
+    aliasable_accesses: tuple[MemoryAccessPair, ...]
+    failure_conditions: tuple[FailureCondition, ...]
+    staging_overlaps: tuple[StagingOverlap, ...]
+    staging_bound_checks: tuple[StagingBoundCheck, ...]
+    parallel_accesses: tuple[ParallelAccessPair, ...]
+    free_vars: frozenset[Sym]
+
+    def merge(self, other: "CoverageSkeleton") -> "CoverageSkeleton":
+        return CoverageSkeleton(
+            self.roots + other.roots,
+            self.aliasable_accesses + other.aliasable_accesses,
+            self.failure_conditions + other.failure_conditions,
+            self.staging_overlaps + other.staging_overlaps,
+            self.staging_bound_checks + other.staging_bound_checks,
+            self.parallel_accesses + other.parallel_accesses,
+            self.free_vars | other.free_vars,
+        )
+
+    def get_coverage_tasks(self) -> tuple[CoverageTask, ...]:
+        return (
+            self.parallel_accesses
+            + self.staging_bound_checks
+            + self.staging_overlaps
+            + self.failure_conditions
+            + self.aliasable_accesses
+            + self.roots
+        )
+
+    def get_indexed_fillers(self) -> Generator[IndexedFiller, None, None]:
+        for task in self.get_coverage_tasks():
+            yield from task.get_indexed_fillers()
+
+    def get_coverage_syms(self) -> frozenset[Sym]:
+        return frozenset().union(
+            *tuple(task.get_coverage_syms() for task in self.get_coverage_tasks()),
+        )
+
+    def update_coverage(self, coverage_result: dict[str, Union[bool, memoryview]]):
+        for task in reversed(self.get_coverage_tasks()):
+            task.update_coverage(coverage_result)
+
+    def get_coverage_progress(self) -> CoverageProgress:
+        result = CoverageProgress(0, 0)
+        for task in self.get_coverage_tasks():
+            # accumulate across tasks; assigning the last task's progress here
+            # would discard the rest
+            result = result.merge(task.get_coverage_progress())
+        return result
+
+    def solve_constraint_with_coverage(
+        self,
+        cm: ConstraintMaker,
+        base_constraint: DisjointConstraint,
+        *,
+        bound: int,
+        search_limit: int,
+    ) -> Optional[Solution]:
+        base_solution = cm.solve_constraint(
+            base_constraint, bound=bound, search_limit=search_limit
+        )
+        if base_solution is None:
+            return None
+        state = CoverageSolverState(
+            base_constraint,
+            True,
+            base_solution,
+            cm,
+            self.free_vars,
+            bound,
+            search_limit,
+        )
+        for task in self.get_coverage_tasks():
+            state = task.solve_coverage(state)
+        return state.current_solution
diff --git a/src/exo/stdlib/scheduling.py b/src/exo/stdlib/scheduling.py
index 9e30eb177..26b867e78 100644
--- a/src/exo/stdlib/scheduling.py
+++ b/src/exo/stdlib/scheduling.py
@@ -29,6 +29,8 @@
     bind_expr,
     commute_expr,
     left_reassociate_expr,
+    insert_mutate,
+    delete_stmt,
     #
     # subprocedure oriented operations
     extract_subproc,
@@ -65,6 +67,7 @@
     parallelize_loop,
     divide_with_recompute,
     divide_loop,
+    divide_loop_min,
     mult_loops,
     cut_loop,
     join_loops,
diff --git a/tests/golden/test_chexo/test_arg_size_constraints.txt b/tests/golden/test_chexo/test_arg_size_constraints.txt
new file mode 100644
index 000000000..84ae7cad0
--- /dev/null
+++ 
b/tests/golden/test_chexo/test_arg_size_constraints.txt @@ -0,0 +1,5 @@ +union( + intersect( + -2 * a_m1 * b_m1 + -2 * a_m1 + -2 * a_m1 + -2 * b_m1 + -2 + -2 + 256 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_chexo/test_free_variables.txt b/tests/golden/test_chexo/test_free_variables.txt new file mode 100644 index 000000000..eac397ee3 --- /dev/null +++ b/tests/golden/test_chexo/test_free_variables.txt @@ -0,0 +1,3 @@ +a: (size, ) +b: (f32[a], ) +i: (index, None) \ No newline at end of file diff --git a/tests/golden/test_chexo/test_get_used_config_fields.txt b/tests/golden/test_chexo/test_get_used_config_fields.txt new file mode 100644 index 000000000..b50f0ce3c --- /dev/null +++ b/tests/golden/test_chexo/test_get_used_config_fields.txt @@ -0,0 +1 @@ +(TestConfig, a): f32 \ No newline at end of file diff --git a/tests/golden/test_chexo/test_path_constraints.txt b/tests/golden/test_chexo/test_path_constraints.txt new file mode 100644 index 000000000..a5f8740bc --- /dev/null +++ b/tests/golden/test_chexo/test_path_constraints.txt @@ -0,0 +1,9 @@ +union( + intersect( + 1 * j_a + -1 * j_b + 0 >= 0, + 1 * a_m1 + 1 + -1 * j_a + 1 * j_b + -1 >= 0, + 1 * a_m1 + 1 + -2 * i_a + 2 * i_b + -1 >= 0, + 1 * i_a + -1 * i_b + 0 >= 0, + 1 * a_m1 + 1 + -1 * i_a + 1 * i_b + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_chexo/test_type_visitor.txt b/tests/golden/test_chexo/test_type_visitor.txt new file mode 100644 index 000000000..bfa6588c7 --- /dev/null +++ b/tests/golden/test_chexo/test_type_visitor.txt @@ -0,0 +1,9 @@ +Types: +a: size +b: f32[a] +c: f32 +i: index +Mems: +a: +b: +c: \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_disjunction.txt b/tests/golden/test_constraint_solver/test_disjunction.txt new file mode 100644 index 000000000..4d14fa6b7 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_disjunction.txt @@ -0,0 +1,10 @@ +union( + intersect( + 3 + -1 * a_m1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 4 + -1 * b_m1 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod.txt b/tests/golden/test_constraint_solver/test_divmod.txt new file mode 100644 index 000000000..bd2c9c003 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_divmod.txt @@ -0,0 +1,12 @@ +union( + intersect( + 4 * a_m1 + 4 + 1 * b_m1 + 1 + -1 * c_m1 + -1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + 1 * rem + -3 == 0, + ), + intersect( + 3 + -1 * a_m1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + 1 * rem + -3 == 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_divmod_solve.txt b/tests/golden/test_constraint_solver/test_divmod_solve.txt new file mode 100644 index 000000000..95424a91d --- /dev/null +++ b/tests/golden/test_constraint_solver/test_divmod_solve.txt @@ -0,0 +1,3 @@ +a = 83 +b = 2 +c = 26 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_inversion.txt b/tests/golden/test_constraint_solver/test_inversion.txt new file mode 100644 index 000000000..3a38b12fd --- /dev/null +++ b/tests/golden/test_constraint_solver/test_inversion.txt @@ -0,0 +1,38 @@ +union( + intersect( + -3 + 1 * a_m1 + 1 + -1 >= 0, + -4 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + -3 + 1 * a_m1 + 1 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + ), + intersect( + -3 + 1 * a_m1 + 1 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + 
-1 >= 0, + ), + intersect( + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + -4 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + ), + intersect( + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + ), + intersect( + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + -4 + 1 * b_m1 + 1 + -1 >= 0, + ), + intersect( + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + 1 * a_m1 + 1 + 1 * b_m1 + 1 + -4 + -1 >= 0, + ), + intersect( + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + -1 * a_m1 + -1 + -1 * b_m1 + -1 + 4 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_large_slack.txt b/tests/golden/test_constraint_solver/test_large_slack.txt new file mode 100644 index 000000000..7aa6ef31c --- /dev/null +++ b/tests/golden/test_constraint_solver/test_large_slack.txt @@ -0,0 +1 @@ +a = 79 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_make_constraint.txt b/tests/golden/test_constraint_solver/test_make_constraint.txt new file mode 100644 index 000000000..8e1b28b4e --- /dev/null +++ b/tests/golden/test_constraint_solver/test_make_constraint.txt @@ -0,0 +1,10 @@ +union( + intersect( + 4 * a_m1 + 4 + 1 * b_m1 + 1 + -1 * c_m1 + -1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 3 + -1 * a_m1 + -1 >= 0, + 5 + -1 * b_m1 + -1 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_solve.txt b/tests/golden/test_constraint_solver/test_solve.txt new file mode 100644 index 000000000..03a44c85b --- /dev/null +++ b/tests/golden/test_constraint_solver/test_solve.txt @@ -0,0 +1,3 @@ +a = 93 +b = 3 +c = 65 \ No newline at end of file diff --git a/tests/golden/test_constraint_solver/test_xnor.txt b/tests/golden/test_constraint_solver/test_xnor.txt new file mode 100644 index 000000000..efe7ef902 --- /dev/null +++ b/tests/golden/test_constraint_solver/test_xnor.txt @@ -0,0 +1,12 @@ +union( + intersect( + -1 * b_m1 + -1 + 1 * a_m1 + 1 + -1 >= 0, + -1 * a_m1 + -1 + 0 + -1 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), + intersect( + 1 * b_m1 + 1 + -1 * a_m1 + -1 >= 0, + 1 * a_m1 + 1 + 0 >= 0, + 4 + -1 * a_m1 + -1 + -1 * b_m1 + -1 + -1 >= 0, + ), +) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_matmul.txt b/tests/golden/test_transpiler/test_matmul.txt new file mode 100644 index 000000000..ea657ab2c --- /dev/null +++ b/tests/golden/test_transpiler/test_matmul.txt @@ -0,0 +1,16 @@ +((a_4,b_5,c_6)=>{ +ctxt={} +if(!((($size_a_4_0)-(0))==($N_1))&&((($size_a_4_1)-(0))==($K_3)))return [1,ctxt,{}]; +if(!((($size_b_5_0)-(0))==($K_3))&&((($size_b_5_1)-(0))==($M_2)))return [1,ctxt,{}]; +if(!((($size_c_6_0)-(0))==($N_1))&&((($size_c_6_1)-(0))==($M_2)))return [1,ctxt,{}]; +for(let i_7=0;i_7<$N_1;i_7++){ +for(let j_8=0;j_8<$M_2;j_8++){ +for(let k_9=0;k_9<$K_3;k_9++){ +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_a_4_0))&&((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_a_4_1)))return [1,ctxt,{}]; +if(!((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_b_5_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_b_5_1)))return [1,ctxt,{}]; +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_c_6_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_c_6_1)))return [1,ctxt,{}]; 
+c_6[Math.imul($stride_c_6_0,((i_7)+(0)))+Math.imul($stride_c_6_1,((j_8)+(0)))]+=(a_4[Math.imul($stride_a_4_0,((i_7)+(0)))+Math.imul($stride_a_4_1,((k_9)+(0)))]*b_5[Math.imul($stride_b_5_0,((k_9)+(0)))+Math.imul($stride_b_5_1,((j_8)+(0)))]); +} +} +} +return [0,ctxt,];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_matmul_coverage.txt b/tests/golden/test_transpiler/test_matmul_coverage.txt new file mode 100644 index 000000000..6d8dd17df --- /dev/null +++ b/tests/golden/test_transpiler/test_matmul_coverage.txt @@ -0,0 +1,26 @@ +((a_4,b_5,c_6)=>{ +ctxt={} +let body_23=false;let body_25=false;let body_27=false;let skip_24=false;let skip_26=false;let skip_28=false; + + + +if(!((($size_a_4_0)-(0))==($N_1))&&((($size_a_4_1)-(0))==($K_3)))return [1,ctxt,{}]; +if(!((($size_b_5_0)-(0))==($K_3))&&((($size_b_5_1)-(0))==($M_2)))return [1,ctxt,{}]; +if(!((($size_c_6_0)-(0))==($N_1))&&((($size_c_6_1)-(0))==($M_2)))return [1,ctxt,{}]; +body_23||=(0<$N_1);skip_24||=(0>=$N_1); +for(let i_7=0;i_7<$N_1;i_7++){ +body_25||=(0<$M_2);skip_26||=(0>=$M_2); +for(let j_8=0;j_8<$M_2;j_8++){ +body_27||=(0<$K_3);skip_28||=(0>=$K_3); +for(let k_9=0;k_9<$K_3;k_9++){ + +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_a_4_0))&&((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_a_4_1)))return [1,ctxt,{}]; + +if(!((((k_9)+(0)))>=(0)&&(((k_9)+(0)))<($size_b_5_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_b_5_1)))return [1,ctxt,{}]; + +if(!((((i_7)+(0)))>=(0)&&(((i_7)+(0)))<($size_c_6_0))&&((((j_8)+(0)))>=(0)&&(((j_8)+(0)))<($size_c_6_1)))return [1,ctxt,{}]; +c_6[Math.imul($stride_c_6_0,((i_7)+(0)))+Math.imul($stride_c_6_1,((j_8)+(0)))]+=(a_4[Math.imul($stride_a_4_0,((i_7)+(0)))+Math.imul($stride_a_4_1,((k_9)+(0)))]*b_5[Math.imul($stride_b_5_0,((k_9)+(0)))+Math.imul($stride_b_5_1,((j_8)+(0)))]); +} +} +} +return [0,ctxt,{body_23,body_25,body_27,skip_24,skip_26,skip_28}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt b/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt new file mode 100644 index 000000000..83489e76e --- /dev/null +++ b/tests/golden/test_transpiler/test_nested_control_flow_coverage.txt @@ -0,0 +1,25 @@ +((b_2)=>{ +ctxt={} +let access_18=false;let access_19=false;let access_22=false;let access_23=false;let body_12=false;let false_case_15=false;let false_case_21=false;let skip_13=false;let true_case_14=false;let true_case_20=false; +body_12||=(0<$n_1);skip_13||=(0>=$n_1); +for(let i_3=0;i_3<$n_1;i_3++){ +if((i_3<(($n_1/2)|0))){ +true_case_14=true; +access_18=true; +b_2[0]=2; +}else{ +false_case_15=true; +access_19=true; +b_2[0]=3; +} +if((i_3==($n_1-1))){ +true_case_20=true; +access_22=true; +b_2[0]+=1; +}else{ +false_case_21=true; +access_23=true; +b_2[0]+=2; +} +} +return [0,ctxt,{access_18,access_19,access_22,access_23,body_12,false_case_15,false_case_21,skip_13,true_case_14,true_case_20}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_variable_length_array_coverage.txt b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt new file mode 100644 index 000000000..7788ca7c3 --- /dev/null +++ b/tests/golden/test_transpiler/test_variable_length_array_coverage.txt @@ -0,0 +1,19 @@ +(()=>{ +ctxt={} +let access_11=new ArrayBuffer(1,{maxByteLength:16});let access_13=new ArrayBuffer(1,{maxByteLength:16});let body_9=false;let skip_10=false; +if(!($n_1>2))return [1,ctxt,{}]; +body_9||=(2<$n_1);skip_10||=(2>=$n_1); +for(let i_2=2;i_2<$n_1;i_2++){ +let size_b_3_0=i_2; +let 
stride_b_3_0=1; +while(Math.imul(size_b_3_0,stride_b_3_0)>access_11.maxByteLength){let temp_12=new ArrayBuffer(access_11.byteLength,{maxByteLength:2*access_11.maxByteLength});for(let i=0;iaccess_13.maxByteLength){let temp_14=new ArrayBuffer(access_13.byteLength,{maxByteLength:2*access_13.maxByteLength});for(let i=0;i=0))return [1,ctxt,{}]; +let b_3=new Int32Array(Math.imul(size_b_3_0,stride_b_3_0));for(let i=0;i=(0)&&((((i_2-1))+(0)))<(size_b_3_0)))return [1,ctxt,{}]; +b_3[Math.imul(stride_b_3_0,(((i_2-1))+(0)))]=0; +access_13[Math.imul((((i_2-2))+(0)),stride_b_3_0)]=1; +if(!(((((i_2-2))+(0)))>=(0)&&((((i_2-2))+(0)))<(size_b_3_0)))return [1,ctxt,{}]; +b_3[Math.imul(stride_b_3_0,(((i_2-2))+(0)))]=1; +} +return [0,ctxt,{access_11,access_13,body_9,skip_10}];}) \ No newline at end of file diff --git a/tests/golden/test_transpiler/test_window_coverage.txt b/tests/golden/test_transpiler/test_window_coverage.txt new file mode 100644 index 000000000..ba116af67 --- /dev/null +++ b/tests/golden/test_transpiler/test_window_coverage.txt @@ -0,0 +1,15 @@ +((a_1)=>{ +ctxt={} +let access_10=new ArrayBuffer(1,{maxByteLength:16});let access_8=new ArrayBuffer(1,{maxByteLength:16}); +while(Math.imul($size_a_1_0,$stride_a_1_0)>access_10.maxByteLength){let temp_11=new ArrayBuffer(access_10.byteLength,{maxByteLength:2*access_10.maxByteLength});for(let i=0;iaccess_8.maxByteLength){let temp_9=new ArrayBuffer(access_8.byteLength,{maxByteLength:2*access_8.maxByteLength});for(let i=0;i=(0)&&(((3)+(0)))<($size_a_1_0)))return [1,ctxt,{}]; +a_1[Math.imul($stride_a_1_0,((3)+(0)))]=2; +access_10[Math.imul(((2)+(lo_6)),$stride_a_1_0)]=1; +if(!((((2)+(lo_6)))>=(lo_6)&&(((2)+(lo_6)))<(hi_7)))return [1,ctxt,{}]; +a_1[Math.imul($stride_a_1_0,((2)+(lo_6)))]=3; +return [0,ctxt,{access_10,access_8}];}) \ No newline at end of file diff --git a/tests/test_chexo.py b/tests/test_chexo.py new file mode 100644 index 000000000..7c6da8226 --- /dev/null +++ b/tests/test_chexo.py @@ -0,0 +1,73 @@ +from __future__ import annotations +from exo.core.prelude import Sym + +from exo.rewrite.chexo.chexo import ( + TypeVisitor, + get_free_variables, + collect_path_constraints, +) +from exo.rewrite.chexo.constraint_solver import ConstraintMaker +from exo import proc, config +from exo.core.memory import StaticMemory + + +def stringify_dict(d): + def check_tuple(x): + if isinstance(x, tuple): + return f"({', '.join(str(x_item) for x_item in x)})" + else: + return str(x) + + return "\n".join( + sorted(f"{check_tuple(k)}: {check_tuple(v)}" for k, v in d.items()) + ) + + +def test_type_visitor(golden): + @proc + def foo(a: size, b: f32[a]): + for i in seq(0, a): + c: f32 @ StaticMemory + c = b[i] * 2 + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + types = stringify_dict(type_visitor.type_map) + mems = stringify_dict(type_visitor.mem_map) + assert golden == f"Types:\n{types}\nMems:\n{mems}" + + +def test_free_variables(golden): + @proc + def foo(a: size, b: f32[a]): + for i in seq(0, a): + c: f32 @ StaticMemory + c = b[i] * 2 + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + free_vars = get_free_variables( + type_visitor.type_map, + type_visitor.mem_map, + foo.find("c: _")._impl.as_block().expand(), + ) + assert golden == stringify_dict(free_vars) + + +def test_path_constraints(golden): + @proc + def foo(a: size, b: f32[a]): + for i in seq(0, a): + if 2 * i < a: + b[i] = 0 + else: + for j in seq(0, a): + b[j] = b[i] + + type_visitor = TypeVisitor() + type_visitor.visit(foo._loopir_proc) + cm = 
ConstraintMaker(type_visitor.type_map) + assert ( + golden + == collect_path_constraints(foo.find("b[j] = b[i]")._impl, cm).pretty_print() + ) diff --git a/tests/test_constraint_solver.py b/tests/test_constraint_solver.py new file mode 100644 index 000000000..793690b85 --- /dev/null +++ b/tests/test_constraint_solver.py @@ -0,0 +1,105 @@ +from __future__ import annotations +from exo.core.prelude import Sym + +from exo.rewrite.chexo.constraint_solver import ConstraintMaker, DisjointConstraint +from exo.core.LoopIR import T +from exo import proc +from exo.rewrite.chexo.chexo import TypeVisitor + + +def stringify_proc_constraint(p, invert=False): + p_type = TypeVisitor() + p_type.visit(p._loopir_proc) + constraint = ConstraintMaker(p_type.type_map).make_constraint( + p._loopir_proc.preds[0] + ) + return (constraint.invert() if invert else constraint).pretty_print() + + +def solve_proc_assertion(p): + p_type = TypeVisitor() + p_type.visit(p._loopir_proc) + cm = ConstraintMaker(p_type.type_map) + constraint = cm.make_constraint(p._loopir_proc.preds[0]) + return "\n".join( + sorted( + [ + f"{str(sym)} = {val}" + for sym, val in cm.solve_constraint( + constraint, bound=100, search_limit=10, seed=13 + ).var_assignments.items() + ] + ) + ) + + +def test_make_constraint(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) + pass + + assert golden == stringify_proc_constraint(foo) + + +def test_solve(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) + pass + + assert golden == solve_proc_assertion(foo) + + +def test_divmod(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) and (a % 4 == 3) + pass + + assert golden == stringify_proc_constraint(foo) + + +def test_divmod_solve(golden): + @proc + def foo(a: size, b: size, c: size): + assert ((a * 4 + b > c) or (a <= 3)) and (b < 5) and (a % 4 == 3) + pass + + assert golden == solve_proc_assertion(foo) + + +def test_large_slack(golden): + @proc + def foo(a: size): + assert a <= 1000000 + pass + + assert golden == solve_proc_assertion(foo) + + +def test_disjunction(golden): + @proc + def foo(a: size, b: size): + assert (a <= 3 or b <= 4) and (a + b < 4) + pass + + assert golden == stringify_proc_constraint(foo) + + +def test_xnor(golden): + @proc + def foo(a: size, b: size): + assert (a <= b) == (a >= 0) and (a + b < 4) + pass + + assert golden == stringify_proc_constraint(foo) + + +def test_inversion(golden): + @proc + def foo(a: size, b: size): + assert (a <= 3 or b <= 4) and (a + b == 4) + pass + + assert golden == stringify_proc_constraint(foo, True) diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py new file mode 100644 index 000000000..3b201c285 --- /dev/null +++ b/tests/test_transpiler.py @@ -0,0 +1,105 @@ +from __future__ import annotations +from exo.core.prelude import Sym + +from exo.rewrite.chexo.constraint_solver import ConstraintMaker, DisjointConstraint +from exo.core.LoopIR import T +from exo import proc +from exo.rewrite.chexo.chexo import TypeVisitor +from exo.rewrite.chexo.LoopIR_transpiler import Transpiler, CoverageArgs + + +def get_coverage_args(p) -> CoverageArgs: + p_type = TypeVisitor() + p_type.visit(p._loopir_proc) + cm = ConstraintMaker(p_type.type_map) + return CoverageArgs(cm) + + +def test_matmul(golden): + Sym._unq_count = 1 + + @proc + def matmul(N: size, M: size, K: size, a: f32[N, K], b: f32[K, M], c: f32[N, M]): + for i in seq(0, N): + for j in 
seq(0, M): + for k in seq(0, K): + c[i, j] += a[i, k] * b[k, j] + + assert golden == Transpiler(matmul._loopir_proc).get_javascript_template().template + + +def test_matmul_coverage(golden): + Sym._unq_count = 1 + + @proc + def matmul(N: size, M: size, K: size, a: f32[N, K], b: f32[K, M], c: f32[N, M]): + for i in seq(0, N): + for j in seq(0, M): + for k in seq(0, K): + c[i, j] += a[i, k] * b[k, j] + + assert ( + golden + == Transpiler(matmul._loopir_proc, get_coverage_args(matmul)) + .get_javascript_template() + .template + ) + + +def test_window_coverage(golden): + Sym._unq_count = 1 + + @proc + def foo(a: i32[16]): + a_win = a[1:8] + a[3] = 2 + a_win[2] = 3 + + assert ( + golden + == Transpiler(foo._loopir_proc, get_coverage_args(foo)) + .get_javascript_template() + .template + ) + + +def test_variable_length_array_coverage(golden): + Sym._unq_count = 1 + + @proc + def foo(n: size): + assert n > 2 + for i in seq(2, n): + b: i32[i] + b[i - 1] = 0 + b[i - 2] = 1 + + assert ( + golden + == Transpiler(foo._loopir_proc, get_coverage_args(foo)) + .get_javascript_template() + .template + ) + + +def test_nested_control_flow_coverage(golden): + Sym._unq_count = 1 + + @proc + def foo(n: size, b: f32): + for i in seq(0, n): + if i < n / 2: + b = 2 + else: + b = 3 + if i == n - 1: + b += 1 + else: + b += 2 + + assert ( + golden + == Transpiler(foo._loopir_proc, get_coverage_args(foo)) + .get_javascript_template() + .template + )
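
A minimal usage sketch (not part of this diff) of how the pieces exercised by
these tests fit together, mirroring the helpers in tests/test_constraint_solver.py;
the procedure `foo` and the bound/search_limit/seed values below are
illustrative assumptions, and the solved values are seed-dependent:

    from __future__ import annotations

    from exo import proc
    from exo.rewrite.chexo.chexo import TypeVisitor
    from exo.rewrite.chexo.constraint_solver import ConstraintMaker

    @proc
    def foo(a: size, b: size):
        assert (a <= 7 or b <= 7) and (a + b < 12)
        pass

    # collect the types of the proc's symbols, then build a constraint maker
    tv = TypeVisitor()
    tv.visit(foo._loopir_proc)
    cm = ConstraintMaker(tv.type_map)

    # turn the proc's assertion into a DisjointConstraint (a union of clauses)
    constraint = cm.make_constraint(foo._loopir_proc.preds[0])

    # randomized search for an integer assignment; None means nothing found
    # within the search limit
    solution = cm.solve_constraint(constraint, bound=100, search_limit=10, seed=13)
    if solution is not None:
        for sym, val in solution.var_assignments.items():
            print(f"{sym} = {val}")  # seed-dependent values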