diff --git a/python/examples/workload/example.py b/python/examples/workload/example.py new file mode 100644 index 0000000..1753c0e --- /dev/null +++ b/python/examples/workload/example.py @@ -0,0 +1,181 @@ +""" +Workload example: Element-wise sum of two (M, N) float32 arrays on CPU. +""" +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor +from mlir.dialects import func, linalg, bufferization +from mlir.dialects import transform +from functools import cached_property +from lighthouse import Workload +from lighthouse.utils.mlir import ( + apply_registered_pass, + canonicalize, + cse, + match, +) +from lighthouse.utils.execution import ( + lower_payload, + execute, + benchmark, +) + + +class ElementwiseSum(Workload): + """ + Computes element-wise sum of (M, N) float32 arrays on CPU. + + We can construct the input arrays and compute the reference solution in + Python with Numpy. + + We use @cached_property to store the inputs and reference solution in the + object so that they are only computed once. + """ + + def __init__(self, M, N): + self.M = M + self.N = N + self.dtype = np.float32 + self.context = ir.Context() + self.location = ir.Location.unknown(context=self.context) + + @cached_property + def _input_arrays(self): + print(" * Generating input arrays...") + np.random.seed(2) + A = np.random.rand(self.M, self.N).astype(self.dtype) + B = np.random.rand(self.M, self.N).astype(self.dtype) + C = np.zeros((self.M, self.N), dtype=self.dtype) + return [A, B, C] + + @cached_property + def _reference_solution(self): + print(" * Computing reference solution...") + A, B, _ = self._input_arrays + return A + B + + def get_input_arrays(self, execution_engine): + return [ + get_ranked_memref_descriptor(a) for a in self._input_arrays + ] + + def verify(self, execution_engine, verbose: int = 0) -> bool: + C = self._input_arrays[2] + C_ref = self._reference_solution + if verbose > 1: + print("Reference solution:") + print(C_ref) + print("Computed solution:") + print(C) + success = np.allclose(C, C_ref) + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + return success + + def requirements(self): + return [] + + def get_complexity(self): + nbytes = np.dtype(self.dtype).itemsize + flop_count = self.M * self.N # one addition per element + memory_reads = 2 * self.M * self.N * nbytes # read A and B + memory_writes = self.M * self.N * nbytes # write C + return (flop_count, memory_reads, memory_writes) + + def payload_module(self): + with self.context, self.location: + float32_t = ir.F32Type.get() + shape = (self.M, self.N) + tensor_t = ir.RankedTensorType.get(shape, float32_t) + memref_t = ir.MemRefType.get(shape, float32_t) + mod = ir.Module.create() + with ir.InsertionPoint(mod.body): + args = [memref_t, memref_t, memref_t] + f = func.FuncOp(self.payload_function_name, (tuple(args), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + A = f.arguments[0] + B = f.arguments[1] + C = f.arguments[2] + a_tensor = bufferization.ToTensorOp(tensor_t, A, restrict=True) + b_tensor = bufferization.ToTensorOp(tensor_t, B, restrict=True) + c_tensor = bufferization.ToTensorOp( + tensor_t, C, restrict=True, writable=True + ) + add = linalg.add(a_tensor, b_tensor, outs=[c_tensor]) + bufferization.MaterializeInDestinationOp( + None, add, C, restrict=True, writable=True + ) + func.ReturnOp(()) + return mod + + def schedule_module(self, dump_kernel=None, parameters=None): + with self.context, self.location: + schedule_module = ir.Module.create() + schedule_module.operation.attributes[ + "transform.with_named_sequence"] = (ir.UnitAttr.get()) + with ir.InsertionPoint(schedule_module.body): + named_sequence = transform.NamedSequenceOp( + "__transform_main", + [transform.AnyOpType.get()], + [], + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + with ir.InsertionPoint(named_sequence.body): + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + mod = apply_registered_pass(mod, "one-shot-bufferize") + mod = apply_registered_pass(mod, "convert-linalg-to-loops") + cse(mod) + canonicalize(mod) + + if dump_kernel == "bufferized": + transform.YieldOp() + return schedule_module + + mod = apply_registered_pass(mod, "convert-scf-to-cf") + mod = apply_registered_pass(mod, "finalize-memref-to-llvm") + mod = apply_registered_pass(mod, "convert-cf-to-llvm") + mod = apply_registered_pass(mod, "convert-arith-to-llvm") + mod = apply_registered_pass(mod, "convert-func-to-llvm") + mod = apply_registered_pass(mod, + "reconcile-unrealized-casts") + transform.YieldOp() + + return schedule_module + + +if __name__ == "__main__": + wload = ElementwiseSum(400, 400) + + print(" Dump kernel ".center(60, "-")) + lower_payload(wload, dump_kernel="bufferized", dump_schedule=True) + + print(" Execute 1 ".center(60, "-")) + execute(wload, verbose=2) + + print(" Execute 2 ".center(60, "-")) + execute(wload, verbose=1) + + print(" Benchmark ".center(60, "-")) + times = benchmark(wload) + times *= 1e6 # convert to microseconds + # compute statistics + mean = np.mean(times) + min = np.min(times) + max = np.max(times) + std = np.std(times) + print(f"Timings (us): " + f"mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}") + flop_count = wload.get_complexity()[0] + gflops = flop_count / (mean * 1e-6) / 1e9 + print(f"Throughput: {gflops:.2f} GFLOPS") diff --git a/python/examples/workload/example_mlir.py b/python/examples/workload/example_mlir.py new file mode 100644 index 0000000..68537b6 --- /dev/null +++ b/python/examples/workload/example_mlir.py @@ -0,0 +1,217 @@ +""" +Workload example: Element-wise sum of two (M, N) float32 arrays on CPU. + +In this example, allocation and deallocation of input arrays is done in MLIR. +""" +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import ( + ranked_memref_to_numpy, + make_nd_memref_descriptor, + as_ctype, +) +from mlir.dialects import func, linalg, bufferization, arith, memref +from mlir.dialects import transform +import ctypes +from contextlib import contextmanager +from lighthouse import Workload +from lighthouse.utils.mlir import ( + apply_registered_pass, + canonicalize, + cse, + match, +) +from lighthouse.utils import get_packed_arg +from lighthouse.utils.execution import ( + lower_payload, + execute, + benchmark, +) +from example import ElementwiseSum + + +def emit_host_alloc(mod, suffix, element_type, rank=2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type) + index_t = ir.IndexType.get() + i32_t = ir.IntegerType.get_signless(32) + with ir.InsertionPoint(mod.body): + f = func.FuncOp( + "host_alloc_" + suffix, (rank*(i32_t,), (memref_dyn_t,)) + ) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + dims = [ + arith.IndexCastOp(index_t, a) for a in list(f.arguments) + ] + alloc = memref.alloc(memref_dyn_t, dims, []) + func.ReturnOp((alloc,)) + + +def emit_host_dealloc(mod, suffix, element_type, rank=2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type) + with ir.InsertionPoint(mod.body): + f = func.FuncOp("host_dealloc_" + suffix, ((memref_dyn_t,), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + memref.dealloc(f.arguments[0]) + func.ReturnOp(()) + + +def emit_fill_constant(mod, suffix, value, element_type, rank=2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type) + with ir.InsertionPoint(mod.body): + f = func.FuncOp("host_fill_constant_" + suffix, ((memref_dyn_t,), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + const = arith.constant(element_type, value) + linalg.fill(const, outs=[f.arguments[0]]) + func.ReturnOp(()) + + +def emit_fill_random(mod, suffix, element_type, min=0.0, max=1.0, seed=2): + rank = 2 + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type) + i32_t = ir.IntegerType.get_signless(32) + f64_t = ir.F64Type.get() + with ir.InsertionPoint(mod.body): + f = func.FuncOp("host_fill_random_" + suffix, ((memref_dyn_t,), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + min_cst = arith.constant(f64_t, min) + max_cst = arith.constant(f64_t, max) + seed_cst = arith.constant(i32_t, seed) + linalg.fill_rng_2d(min_cst, max_cst, seed_cst, outs=[f.arguments[0]]) + func.ReturnOp(()) + + +class ElementwiseSumMLIRAlloc(ElementwiseSum): + """ + Computes element-wise sum of (M, N) float32 arrays on CPU. + + Extends ElementwiseSum by allocating input arrays in MLIR. + """ + + def __init__(self, M, N): + super().__init__(M, N) + # keep track of allocated memrefs + self.memrefs = {} + + def _allocate_array(self, name, execution_engine): + if name in self.memrefs: + return self.memrefs[name] + alloc_func = execution_engine.lookup("host_alloc_f32") + shape = (self.M, self.N) + mref = make_nd_memref_descriptor(len(shape), as_ctype(self.dtype))() + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] + alloc_func(get_packed_arg([ptr_mref, *ptr_dims])) + self.memrefs[name] = mref + return mref + + def _allocate_inputs(self, execution_engine): + self._allocate_array("A", execution_engine) + self._allocate_array("B", execution_engine) + self._allocate_array("C", execution_engine) + + def _deallocate_all(self, execution_engine): + for mref in self.memrefs.values(): + dealloc_func = execution_engine.lookup("host_dealloc_f32") + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + dealloc_func(get_packed_arg([ptr_mref])) + self.memrefs = {} + + @contextmanager + def allocate(self, execution_engine): + try: + self._allocate_inputs(execution_engine) + yield None + finally: + self._deallocate_all(execution_engine) + + def get_input_arrays(self, execution_engine): + A = self._allocate_array("A", execution_engine) + B = self._allocate_array("B", execution_engine) + C = self._allocate_array("C", execution_engine) + + # initialize with MLIR + fill_zero_func = execution_engine.lookup("host_fill_constant_zero_f32") + fill_random_func = execution_engine.lookup("host_fill_random_f32") + fill_zero_func(get_packed_arg([ctypes.pointer(ctypes.pointer(C))])) + fill_random_func(get_packed_arg([ctypes.pointer(ctypes.pointer(A))])) + fill_random_func(get_packed_arg([ctypes.pointer(ctypes.pointer(B))])) + + return [A, B, C] + + def verify(self, execution_engine, verbose: int = 0) -> bool: + # compute reference solution with numpy + A = ranked_memref_to_numpy([self.memrefs["A"]]) + B = ranked_memref_to_numpy([self.memrefs["B"]]) + C = ranked_memref_to_numpy([self.memrefs["C"]]) + C_ref = A + B + if verbose > 1: + print("Reference solution:") + print(C_ref) + print("Computed solution:") + print(C) + success = np.allclose(C, C_ref) + + # Alternatively we could have done the verification in MLIR by emitting + # a check function. + # Here we just call the payload function again. + # self._allocate_array("C_ref", execution_engine) + # func = execution_engine.lookup("payload") + # func(get_packed_arg([ + # ctypes.pointer(ctypes.pointer(self.memrefs["A"])), + # ctypes.pointer(ctypes.pointer(self.memrefs["B"])), + # ctypes.pointer(ctypes.pointer(self.memrefs["C_ref"])), + # ])) + # Check correctness with numpy. + # C = ranked_memref_to_numpy([self.memrefs["C"]]) + # C_ref = ranked_memref_to_numpy([self.memrefs["C_ref"]]) + # success = np.allclose(C, C_ref) + + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + return success + + def payload_module(self): + mod = super().payload_module() + # extend the payload module with de/alloc/fill functions + with self.context, self.location: + float32_t = ir.F32Type.get() + emit_host_alloc(mod, "f32", float32_t) + emit_host_dealloc(mod, "f32", float32_t) + emit_fill_constant(mod, "zero_f32", 0.0, float32_t) + emit_fill_random(mod, "f32", float32_t, min=-1.0, max=1.0) + return mod + + +if __name__ == "__main__": + wload = ElementwiseSumMLIRAlloc(400, 400) + + print(" Dump kernel ".center(60, "-")) + lower_payload(wload, dump_kernel="bufferized", dump_schedule=False) + + print(" Execute ".center(60, "-")) + execute(wload, verbose=2) + + print(" Benchmark ".center(60, "-")) + times = benchmark(wload) + times *= 1e6 # convert to microseconds + # compute statistics + mean = np.mean(times) + min = np.min(times) + max = np.max(times) + std = np.std(times) + print(f"Timings (us): " + f"mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}") + flop_count = wload.get_complexity()[0] + gflops = flop_count / (mean * 1e-6) / 1e9 + print(f"Throughput: {gflops:.2f} GFLOPS") diff --git a/python/lighthouse/__init__.py b/python/lighthouse/__init__.py index 1ac008e..4d700c2 100644 --- a/python/lighthouse/__init__.py +++ b/python/lighthouse/__init__.py @@ -1 +1,3 @@ __version__ = "0.1.0a1" + +from .workload import Workload diff --git a/python/lighthouse/utils/execution.py b/python/lighthouse/utils/execution.py new file mode 100644 index 0000000..89e6957 --- /dev/null +++ b/python/lighthouse/utils/execution.py @@ -0,0 +1,239 @@ +""" +Execution engine utility functions. +""" +import numpy as np +import ctypes +import os +from mlir import ir +from mlir.dialects.transform import interpreter as transform_interpreter +from mlir.dialects import func, arith, scf, memref +from mlir.execution_engine import ExecutionEngine +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor +from lighthouse.utils.mlir import get_mlir_library_path +from lighthouse.utils import get_packed_arg +from lighthouse import Workload +from typing import Optional + + +def get_engine(payload_module, requirements=None, opt_level=3) -> ExecutionEngine: + requirements = requirements or [] + context = ir.Context() + location = ir.Location.unknown(context) + required_libs = { + "levelzero": ( + ["libmlir_levelzero_runtime.so"], + "Did you compile LLVM with -DMLIR_ENABLE_LEVELZERO_RUNNER=1?", + ), + "mlir_runner": (["libmlir_runner_utils.so"], ""), + "mlir_c_runner": (["libmlir_c_runner_utils.so"], ""), + } + libs = [] + lib_dir = os.path.join(get_mlir_library_path()) + for r in requirements: + if r not in required_libs: + raise ValueError(f"Unknown execution engine requirement: {r}") + so_files, hint = required_libs[r] + for f in so_files: + so_path = os.path.join(lib_dir, f) + if not os.path.isfile(so_path): + msg = f"Could not find shared library {so_path}" + if hint: + msg += "\n" + hint + raise ValueError(msg) + libs.append(so_path) + with context, location: + execution_engine = ExecutionEngine( + payload_module, opt_level=opt_level, shared_libs=libs + ) + execution_engine.initialize() + return execution_engine + + +def apply_transform_schedule( + payload_module, + schedule_module, + context, + location, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, +): + if not dump_kernel or dump_kernel != "initial": + with context, location: + # invoke transform interpreter directly + transform_interpreter.apply_named_sequence( + payload_root=payload_module, + transform_root=schedule_module.body.operations[0], + transform_module=schedule_module, + ) + if dump_kernel: + print(payload_module) + if dump_schedule: + print(schedule_module) + + +def lower_payload( + workload, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, + schedule_parameters: Optional[dict] = None, +) -> ir.Module: + payload_module = workload.payload_module() + schedule_module = workload.schedule_module( + dump_kernel=dump_kernel, parameters=schedule_parameters + ) + apply_transform_schedule( + payload_module, + schedule_module, + workload.context, + workload.location, + dump_kernel=dump_kernel, + dump_schedule=dump_schedule, + ) + return payload_module + + +def execute( + workload, + check_correctness: bool = True, + schedule_parameters: Optional[dict] = None, + verbose: int = 0, +): + # lower payload with schedule + payload_module = lower_payload( + workload, schedule_parameters=schedule_parameters + ) + # get execution engine + engine = get_engine( + payload_module, requirements=workload.requirements() + ) + + with workload.allocate(execution_engine=engine): + # prepare function arguments + inputs = workload.get_input_arrays(execution_engine=engine) + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + packed_args = get_packed_arg(pointers) + + # handle to payload function + payload_func = engine.lookup(workload.payload_function_name) + + # call function + payload_func(packed_args) + + if check_correctness: + success = workload.verify(execution_engine=engine, verbose=verbose) + if not success: + raise ValueError("Benchmark verification failed.") + + +def emit_benchmark_function( + payload_module: ir.Module, workload: Workload, nruns: int, nwarmup: int +): + """ + Emit a benchmark function that calls payload function and times it. + + Every function call is timed separately. Returns the times (seconds) in a + memref. + """ + # find original payload function + payload_func = None + for op in payload_module.operation.regions[0].blocks[0]: + if (isinstance(op, func.FuncOp) and + str(op.name).strip('"') == workload.payload_function_name): + payload_func = op + break + assert payload_func is not None, "Could not find payload function" + payload_arguments = payload_func.type.inputs + # emit benchmark function + with workload.context, workload.location: + with ir.InsertionPoint(payload_module.body): + # define rtclock function + f64_t = ir.F64Type.get() + f = func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit new function + time_memref_t = ir.MemRefType.get((nruns,), f64_t) + args = payload_arguments + [time_memref_t] + f = func.FuncOp("benchmark", (tuple(args), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + index_t = ir.IndexType.get() + zero = arith.ConstantOp(index_t, 0) + one = arith.ConstantOp(index_t, 1) + # call payload for warmup runs + nwarmup_cst = arith.ConstantOp(index_t, nwarmup) + for_op = scf.ForOp(zero, nwarmup_cst, one) + with ir.InsertionPoint(for_op.body): + func.CallOp( + payload_func, list(f.arguments[:len(payload_arguments)]) + ) + scf.YieldOp(()) + # call payload for benchmark runs, time every call separately + nruns_cst = arith.ConstantOp(index_t, nruns) + for_op = scf.ForOp(zero, nruns_cst, one) + i = for_op.induction_variable + with ir.InsertionPoint(for_op.body): + tic = func.CallOp((f64_t,), "rtclock", ()).result + func.CallOp( + payload_func, list(f.arguments[:len(payload_arguments)]) + ) + toc = func.CallOp((f64_t,), "rtclock", ()).result + time = arith.SubFOp(toc, tic) + memref.StoreOp(time, f.arguments[-1], [i]) + scf.YieldOp(()) + func.ReturnOp(()) + + +def benchmark( + workload, + nruns: int = 100, + nwarmup: int = 10, + schedule_parameters: Optional[dict] = None, + check_correctness: bool = True, + verbose: int = 0, +) -> np.ndarray: + + # get original payload module + payload_module = workload.payload_module() + + # add benchmark function with timing + emit_benchmark_function(payload_module, workload, nruns, nwarmup) + + # lower + apply_transform_schedule( + payload_module, + workload.schedule_module(parameters=schedule_parameters), + workload.context, + workload.location, + ) + # get execution engine, rtclock requires mlir_c_runner + requirements = workload.requirements() + if "mlir_c_runner" not in requirements: + requirements.append("mlir_c_runner") + engine = get_engine( + payload_module, requirements=requirements + ) + + with workload.allocate(execution_engine=engine): + inputs = workload.get_input_arrays(execution_engine=engine) + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + if check_correctness: + # call payload once to verify correctness + # prepare function arguments + packed_args = get_packed_arg(pointers) + + payload_func = engine.lookup(workload.payload_function_name) + payload_func(packed_args) + success = workload.verify(execution_engine=engine, verbose=verbose) + if not success: + raise ValueError("Benchmark verification failed.") + + # allocate buffer for timings and prepare arguments + time_array = np.zeros((nruns,), dtype=np.float64) + time_memref = get_ranked_memref_descriptor(time_array) + time_pointer = ctypes.pointer(ctypes.pointer(time_memref)) + packed_args_with_time = get_packed_arg(pointers + [time_pointer]) + + # call benchmark function + benchmark_func = engine.lookup("benchmark") + benchmark_func(packed_args_with_time) + + return time_array diff --git a/python/lighthouse/utils/mlir.py b/python/lighthouse/utils/mlir.py new file mode 100644 index 0000000..90b248b --- /dev/null +++ b/python/lighthouse/utils/mlir.py @@ -0,0 +1,36 @@ +""" +MLIR utility functions. +""" +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured +import os + + +def apply_registered_pass(*args, **kwargs): + return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs) + + +def match(*args, **kwargs): + return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs) + + +def cse(op): + transform.ApplyCommonSubexpressionEliminationOp(op) + + +def canonicalize(op): + with ir.InsertionPoint(transform.ApplyPatternsOp(op).patterns): + transform.ApplyCanonicalizationPatternsOp() + + +def get_mlir_library_path(): + pkg_path = ir.__file__ + if "python_packages" in pkg_path: + # looks like a local mlir install + path = pkg_path.split("python_packages")[0] + os.sep + "lib" + else: + # maybe installed in python path + path = os.path.split(pkg_path)[0] + os.sep + "_mlir_libs" + assert os.path.isdir(path) + return path diff --git a/python/lighthouse/workload.py b/python/lighthouse/workload.py new file mode 100644 index 0000000..9f03868 --- /dev/null +++ b/python/lighthouse/workload.py @@ -0,0 +1,77 @@ +""" +Abstract base class for workloads. + +Defines the expected interface for generic workload execution methods. +""" +from mlir import ir +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Optional + + +class Workload(ABC): + """ + Abstract base class for workloads. + + A workload is defined by a fixed payload function and problem size. + Different realizations of the workload can be obtained by altering the + lowering schedule. + + The MLIR payload function should take input arrays as memrefs and return + nothing. + """ + payload_function_name: str = "payload" + + @abstractmethod + def requirements(self) -> list[str]: + """Return a list of requirements for the execution engine.""" + pass + + @abstractmethod + def payload_module(self) -> ir.Module: + """Generate the MLIR module containing the payload function.""" + pass + + @abstractmethod + def schedule_module( + self, + dump_kernel: Optional[str] = None, + parameters: Optional[dict] = None, + ) -> ir.Module: + """Generate the MLIR module containing the transform schedule.""" + pass + + @abstractmethod + def get_input_arrays(self, execution_engine) -> list: + """ + Return the input arrays for the payload function as memrefs. + + Allocation and initialization of the input arrays should be done here. + """ + pass + + @contextmanager + def allocate(self, execution_engine): + """ + Allocate any necessary memory for the workload. + + Override this method if the workload requires memory management.""" + try: + yield None + finally: + pass + + @abstractmethod + def verify(self, execution_engine, verbose: int = 0) -> bool: + """Verify the correctness of the computation.""" + pass + + @abstractmethod + def get_complexity(self) -> list: + """ + Return the computational complexity of the workload. + + Returns a tuple (flop_count, memory_reads, memory_writes). Memory + reads/writes are in bytes. + """ + pass