From f7752ef7ef16ad9db8f367935950dbf1ec06e6b8 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 10 Jan 2025 04:24:59 +0000 Subject: [PATCH 1/4] [Frontend] Split scheduling module --- .../mlir/mlir_codegen_backend.py | 156 ----------------- PyTorchSimFrontend/mlir/mlir_scheduling.py | 161 ++++++++++++++++++ Scheduler/scheduler.py | 4 +- 3 files changed, 164 insertions(+), 157 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_scheduling.py diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 09c5aa34..d5336fdb 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -11,20 +11,15 @@ from typing import Dict from collections import OrderedDict import torch -from torch._inductor import dependencies, config from torch._inductor.codegen import cpp, wrapper, common -from torch._inductor.scheduler import BaseScheduling from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic, write -from Simulator.simulator import BackendSimulator -from PyTorchSimFrontend import extension_config from torch._inductor.utils import ( IndentedBuffer, is_welford_reduction, ) import PyTorchSimFrontend.extension_codecache as extension_codecache - from . import mlir_common def reduction_init(reduction_type, dtype): @@ -1319,8 +1314,6 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape def roundup_vectorlane(self, size, amp=1): return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp -from . 
import mlir_lowering - @dataclasses.dataclass class LoopLevel: var: sympy.Expr @@ -1364,152 +1357,3 @@ def mark_parallel(self, par_depth): for i in range(1, par_depth): loops[i].collapsed = True loops[0].simd = loops[par_depth - 1].simd - -class MLIRWrapperKenrelGroup(cpp.KernelGroup): - def __init__(self): - super().__init__() - self.args = mlir_common.MLIRKernelArgs() - -class MLIRScheduling(BaseScheduling): - count = 0 - target_kernel = MLIRKernel - def __init__(self, scheduler): - self.scheduler = scheduler - self.kernel_group = MLIRWrapperKenrelGroup() - self._ready_to_flush = False - self.outer_function = set() - config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! - - def _set_flush_status(self, status: bool): - self._ready_to_flush = status - - def can_fuse_vertical(self, node1, node2): - return False - return self.can_fuse_horizontal(node1, node2) and not node1.is_reduction() - - def can_fuse_horizontal(self, node1, node2): - return False - _, (vars1, reduce1) = node1.group - _, (vars2, reduce2) = node2.group - if vars1 == vars2 and reduce1 == reduce2: - return True - #TODO: Temporary solution determining the fusion condition similar to CPP/OpenMP - v1_total = math.prod(vars1) if len(vars1) else 0 - v2_total = math.prod(vars2) if len(vars2) else 0 - r1_total = math.prod(reduce1) if len(reduce1) else 0 - r2_total = math.prod(reduce2) if len(reduce2) else 0 - if reduce1 == () \ - and v1_total == (v2_total + r2_total): - # and node1.node.layout.size == node2.node.layout.size: #FIXME: Need to check layout too? 
- return True - return False - - def group_fn(self, sizes): - return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - - def codegen_nodes(self, nodes): - _, (group, reduction_group) = max( - nodes, key=lambda x: int(x.is_reduction()) - ).group - ex_kernel = self.target_kernel() - ex_kernel.kernel_group = self.kernel_group - - kernel_name = f"extension_kernel_{MLIRScheduling.count}" - MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name) - self.define_kernel(src_code, kernel_name, ex_kernel.vector_lane, - ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) - ex_kernel.call_kernel(kernel_name) - _, args, _, _ = ex_kernel.args.mlir_argdefs() - args = ", ".join(args) - if (extension_config.CONFIG_BACKENDSIM_EAGER_MODE): - V.graph.wrapper_code.writeline( - f"yield ({kernel_name}, ({args}))" - ) - self._set_flush_status(True) - - def ready_to_flush(self): - return self._ready_to_flush - - def codegen_sync(self): - pass - - def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = MLIRWrapperKenrelGroup() - self._set_flush_status(False) - - def define_function(self, kernel): - code, function_name = kernel.def_function() - if code is not None and function_name not in self.outer_function: - wrapper = V.graph.wrapper_code - wrapper.header.writeline(code) - self.outer_function.add(function_name) - - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, tile_size=[1, 1, 1], loop_size=None, origins={}): - wrapper = V.graph.wrapper_code - if src_code in wrapper.src_to_kernel: - kernel_name = wrapper.src_to_kernel[src_code] - else: - wrapper.src_to_kernel[src_code] = kernel_name - - codecache_def = IndentedBuffer() - codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") - codecache_def.writeline(f"vectorlane_size={vector_lane},") - codecache_def.writeline(f"tile_size={tile_size},") - codecache_def.writeline(f"loop_size={loop_size},") - 
codecache_def.writeline(f"spad_info={spad_info},") - codecache_def.writeline(f"origins={origins},") - codecache_def.writeline("arg_attributes=arg_attributes)") - wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) - return kernel_name - - def codegen_src_code(self, kernel, render, template_node, epilogue_nodes): - with kernel: - for node in [template_node, *epilogue_nodes]: - node.mark_run() - partial_code = render() - for node in epilogue_nodes: - ranges = node.get_ranges() - node.codegen(kernel.set_ranges(ranges[0], ranges[1], None)) - with V.set_kernel_handler(kernel): - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) - src_code = kernel.add_extra_global_vars(src_code) - return src_code - - def codegen_template(self, template_node, epilogue_nodes): - _, (numel, rnumel) = template_node.group - template_buffer = template_node.node - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes) - _, _, _, kernel.buffer_types = kernel.args.mlir_argdefs() - - src_code = self.codegen_src_code(kernel, render, template_node, epilogue_nodes) - wrapper = V.graph.wrapper_code - - if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined - kernel_name = wrapper.src_to_kernel[src_code] - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name - src_code = self.codegen_src_code(kernel, render, template_node, epilogue_nodes) - - with V.set_kernel_handler(kernel): - codegen_header(src_code, (kernel.header.getvalue(), kernel.gem5_header.getvalue())) - # node_schedule = [template_node, *epilogue_nodes] - kernel.meta_kernel() - kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, - kernel.tile_size, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) - 
self.define_function(kernel) - - kernel.call_kernel(kernel_name) - V.graph.removed_buffers |= kernel.removed_buffers - _, args, _, _ = kernel.args.mlir_argdefs() - args = ", ".join(args) - if (extension_config.CONFIG_BACKENDSIM_EAGER_MODE): - target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name - V.graph.wrapper_code.writeline( - f"yield ({target_kernel_name}, ({args}))" - ) - self._set_flush_status(True) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py new file mode 100644 index 00000000..ea2005a8 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -0,0 +1,161 @@ +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel + + +from torch._inductor import config +from torch._inductor.codegen import cpp +from torch._inductor.scheduler import BaseScheduling +from torch._inductor.utils import IndentedBuffer +from torch._inductor.virtualized import V + +from . import mlir_common +from . import mlir_lowering + +class MLIRWrapperKenrelGroup(cpp.KernelGroup): + def __init__(self): + super().__init__() + self.args = mlir_common.MLIRKernelArgs() + +class MLIRScheduling(BaseScheduling): + count = 0 + target_kernel = MLIRKernel + def __init__(self, scheduler): + self.scheduler = scheduler + self.kernel_group = MLIRWrapperKenrelGroup() + self._ready_to_flush = False + self.outer_function = set() + config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! 
+ + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def can_fuse_vertical(self, node1, node2): + return False + return self.can_fuse_horizontal(node1, node2) and not node1.is_reduction() + + def can_fuse_horizontal(self, node1, node2): + return False + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + if vars1 == vars2 and reduce1 == reduce2: + return True + #TODO: Temporary solution determining the fusion condition similar to CPP/OpenMP + v1_total = math.prod(vars1) if len(vars1) else 0 + v2_total = math.prod(vars2) if len(vars2) else 0 + r1_total = math.prod(reduce1) if len(reduce1) else 0 + r2_total = math.prod(reduce2) if len(reduce2) else 0 + if reduce1 == () \ + and v1_total == (v2_total + r2_total): + # and node1.node.layout.size == node2.node.layout.size: #FIXME: Need to check layout too? + return True + return False + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def codegen_nodes(self, nodes): + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + ex_kernel = self.target_kernel() + ex_kernel.kernel_group = self.kernel_group + + kernel_name = f"extension_kernel_{MLIRScheduling.count}" + MLIRScheduling.count += 1 + src_code = ex_kernel.codegen_nodes(nodes, kernel_name) + self.define_kernel(src_code, kernel_name, ex_kernel.vector_lane, + ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) + ex_kernel.call_kernel(kernel_name) + _, args, _, _ = ex_kernel.args.mlir_argdefs() + args = ", ".join(args) + if (extension_config.CONFIG_BACKENDSIM_EAGER_MODE): + V.graph.wrapper_code.writeline( + f"yield ({kernel_name}, ({args}))" + ) + self._set_flush_status(True) + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def flush(self): + self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) + self.kernel_group = MLIRWrapperKenrelGroup() + 
self._set_flush_status(False) + + def define_function(self, kernel): + code, function_name = kernel.def_function() + if code is not None and function_name not in self.outer_function: + wrapper = V.graph.wrapper_code + wrapper.header.writeline(code) + self.outer_function.add(function_name) + + def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, tile_size=[1, 1, 1], loop_size=None, origins={}): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + wrapper.src_to_kernel[src_code] = kernel_name + + codecache_def = IndentedBuffer() + codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") + codecache_def.writeline(f"vectorlane_size={vector_lane},") + codecache_def.writeline(f"tile_size={tile_size},") + codecache_def.writeline(f"loop_size={loop_size},") + codecache_def.writeline(f"spad_info={spad_info},") + codecache_def.writeline(f"origins={origins},") + codecache_def.writeline("arg_attributes=arg_attributes)") + wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) + return kernel_name + + def codegen_src_code(self, kernel, render, template_node, epilogue_nodes): + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + partial_code = render() + for node in epilogue_nodes: + ranges = node.get_ranges() + node.codegen(kernel.set_ranges(ranges[0], ranges[1], None)) + with V.set_kernel_handler(kernel): + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + src_code = kernel.add_extra_global_vars(src_code) + return src_code + + def codegen_template(self, template_node, epilogue_nodes): + _, (numel, rnumel) = template_node.group + template_buffer = template_node.node + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes) + _, _, _, kernel.buffer_types = kernel.args.mlir_argdefs() + + src_code = 
self.codegen_src_code(kernel, render, template_node, epilogue_nodes) + wrapper = V.graph.wrapper_code + + if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined + kernel_name = wrapper.src_to_kernel[src_code] + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name + src_code = self.codegen_src_code(kernel, render, template_node, epilogue_nodes) + + with V.set_kernel_handler(kernel): + codegen_header(src_code, (kernel.header.getvalue(), kernel.gem5_header.getvalue())) + # node_schedule = [template_node, *epilogue_nodes] + kernel.meta_kernel() + kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, + kernel.tile_size, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) + self.define_function(kernel) + + kernel.call_kernel(kernel_name) + V.graph.removed_buffers |= kernel.removed_buffers + _, args, _, _ = kernel.args.mlir_argdefs() + args = ", ".join(args) + if (extension_config.CONFIG_BACKENDSIM_EAGER_MODE): + target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name + V.graph.wrapper_code.writeline( + f"yield ({target_kernel_name}, ({args}))" + ) + self._set_flush_status(True) \ No newline at end of file diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index 1d8064f9..f1f0b35a 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -179,9 +179,11 @@ def setup_device(): register_backend_for_device, ) from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - MLIRScheduling, ExtensionWrapperCodegen, ) + from PyTorchSimFrontend.mlir.mlir_scheduling import ( + MLIRScheduling + ) register_backend_for_device( "extension_device", MLIRScheduling, ExtensionWrapperCodegen ) From 13cc93361aad8a29cb92f274d93399d8b1bad380 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 10 Jan 2025 04:27:08 +0000 Subject: 
[PATCH 2/4] [Frontend] move common function to mlir_common module --- .../mlir/mlir_codegen_backend.py | 35 ------------------ PyTorchSimFrontend/mlir/mlir_common.py | 36 +++++++++++++++++++ 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index d5336fdb..42e091c6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -662,7 +662,6 @@ def __init__(self): self.kernel_group = None self.call_ranges = None self.ranges = None - self.itervars = None self.reduction_depth = None self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() @@ -688,40 +687,6 @@ def __init__(self): self.reduce_iterator = {} self.is_template_kernel = False - def get_constant_vector(self, expr): - constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] - return constant_vector - - def get_constant_vector2(self, expr): - # Case 0. symbol ex) index 0 - # Case 1. inner product form ex) 16 * index0 + 1 * index1 - # Case 2. 
Complicated form ex) 16 * index0 + 8 * (index//4) + (index % 4) - constant_vector = [] - if expr.is_symbol: - constant_vector.append(tuple([1, expr])) - return constant_vector - - for arg in expr.args: - if arg.is_symbol: - constant_vector.append(tuple([1,arg])) - continue - if len(arg.args) == 0: #TODO: check this - continue - if arg.args[0].is_number: - constant_vector.append(arg.args) - else: - constant_vector.append([1, arg]) - - return constant_vector - - def find_node_by_name(self, name): - if name in V.graph.graph_inputs: - return V.graph.graph_inputs[name] - else: - for output_node in V.graph.graph_outputs: - if output_node.data.name == name: - return output_node - def get_dma_info(self, name, index, dtype): current_tile = MLIRTile(self.tile_desc.n_row, self.tile_desc.n_col, self.tile_desc.vector_lane, self.tile_desc.used_vector_lane) cv = self.get_constant_vector(index) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 912704b5..f0740910 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -163,6 +163,8 @@ class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): def __init__(self, args=None): super().__init__(args) + self.itervars = None + self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() self.cse = common.CSE(self.newvar_prefix, self.suffix) @@ -193,6 +195,40 @@ def check_dtype_in_args(self, args): dtype = arg return dtype + def get_constant_vector(self, expr): + constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] + return constant_vector + + def get_constant_vector2(self, expr): + # Case 0. symbol ex) index 0 + # Case 1. inner product form ex) 16 * index0 + 1 * index1 + # Case 2. 
Complicated form ex) 16 * index0 + 8 * (index//4) + (index % 4) + constant_vector = [] + if expr.is_symbol: + constant_vector.append(tuple([1, expr])) + return constant_vector + + for arg in expr.args: + if arg.is_symbol: + constant_vector.append(tuple([1,arg])) + continue + if len(arg.args) == 0: #TODO: check this + continue + if arg.args[0].is_number: + constant_vector.append(arg.args) + else: + constant_vector.append([1, arg]) + + return constant_vector + + def find_node_by_name(self, name): + if name in V.graph.graph_inputs: + return V.graph.graph_inputs[name] + else: + for output_node in V.graph.graph_outputs: + if output_node.data.name == name: + return output_node + def register_var_info(self, var, var_info): self.var_info[var] = var_info From efed9e8d6a929f85ad0916f6e3d9633cb689df6f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 10 Jan 2025 04:50:16 +0000 Subject: [PATCH 3/4] [Frontend] move common logic to mlir_common --- .../mlir/mlir_codegen_backend.py | 81 +------------- PyTorchSimFrontend/mlir/mlir_common.py | 101 +++++++++++++++--- 2 files changed, 89 insertions(+), 93 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 42e091c6..25eadbbe 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -604,65 +604,12 @@ def broadcast(operand1, operand2, *args, var_info=None): "MVOUT": 3, } -class MLIRTile(): - TILE_ROW_WISE = 0 - TILE_COL_WISE = 1 - TILE_PER_LANE_ROW_WISE = 2 - TILE_PER_LANE_COL_WISE = 3 - def __init__(self, n_row, n_col, vector_lane, used_vector_lane=None) -> None: - self.n_row = n_row - self.n_col = n_col - self.vector_lane = vector_lane - if used_vector_lane is None: - self.used_vector_lane = self.vector_lane - else: - self.used_vector_lane = used_vector_lane - self.tile_per_lane_layout = self.TILE_PER_LANE_ROW_WISE # How a given tile per lane is stored - self.tile_layout = self.TILE_ROW_WISE # 
How a given tile is stored per lane - self.vector_lane_axis = (self.n_col//self.used_vector_lane) > 0 #(0: Col major, 1: Row major) - - def get_tile_size(self): - return self.n_row * self.n_col - - def get_rows_per_lane(self): - if self.n_row % self.used_vector_lane != 0 and self.n_row > 1: - print(f"[Warning] n_row({self.n_row}) % vector_lane({self.used_vector_lane}) != 0") - return self.div_round_up(self.n_row, self.used_vector_lane) - - def get_cols_per_lane(self): - if self.n_col % self.used_vector_lane != 0 and self.n_col > 1: - print(f"[Warning] n_col({self.n_col}) % vector_lane({self.used_vector_lane}) != 0") - return self.div_round_up(self.n_col, self.used_vector_lane) - - def get_tile_size_per_lane(self): - if self.get_tile_size() % self.used_vector_lane != 0: - print(f"[Warning] n_col({self.n_col}) % vector_lane({self.used_vector_lane}) != 0") - return self.div_round_up(self.get_tile_size(), self.used_vector_lane) - - def get_tile_shape(self): - return f"{self.n_row}x{self.n_col}" - - def get_chunk_size(self): - if self.tile_layout == self.TILE_ROW_WISE: - chunk_size = self.get_tile_size_per_lane() - else: - chunk_size = self.get_cols_per_lane() - return chunk_size - - @staticmethod - def div_round_up(size, round_val): - return (size + round_val - 1) // round_val - class MLIRKernel(mlir_common.BaseMLIRKernel): overrides = ExtensionOverrides newvar_prefix = "%" def __init__(self): super().__init__(mlir_common.MLIRKernelArgs()) - self.kernel_group = None - self.call_ranges = None - self.ranges = None - self.reduction_depth = None self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() self.body = IndentedBuffer() @@ -678,7 +625,6 @@ def __init__(self): self.map_cse = common.CSE("#", self.suffix, name_prefix="map") self.consts = set() self.tags = set() - self.tile_desc = MLIRTile(self.tile_row, self.tile_col, self.vector_lane) self.dma_cache = {} self.dma_counter = 1 self.reduction_idx = {} @@ -1067,12 +1013,6 @@ def 
store_reduction(self, name, index, value): self.cse.generate(self.reductions_suffix, code, assignment = False) def codegen_body(self): - # if not ( - # self.loads - # or self.stores - # or self.compute - # ): - # return def template_store(options): subtile_size = [self.vector_lane, self.vector_lane] async_flag = 1 @@ -1228,25 +1168,11 @@ def adjust_tile_size(self): raise NotImplementedError() def set_ranges(self, lengths, reduction_lengths, read_writes): - self.read_writes = read_writes - if self.call_ranges: - assert self.call_ranges == tuple(lengths) + tuple( - reduction_lengths - ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" - assert self.reduction_depth == len(lengths) - else: - self.call_ranges = tuple(lengths) + tuple(reduction_lengths) - self.ranges = [self.rename_indexing(x) for x in self.call_ranges] - self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] - self.reduction_depth = len(lengths) + ret = super().set_ranges(lengths, reduction_lengths, read_writes) # Adjust time size when it is vector self.adjust_tile_size() - - return ( - self.itervars[: self.reduction_depth], - self.itervars[self.reduction_depth :], - ) + return ret def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): c_type = mlir_common.DTYPE_TO_C[dtype] @@ -1276,8 +1202,7 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer, indices - def roundup_vectorlane(self, size, amp=1): - return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp + @dataclasses.dataclass class LoopLevel: diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index f0740910..df44a810 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py 
@@ -141,6 +141,56 @@ def set_info(outer, inner, arg_type): set_info(outer, inner, self.MLIR_ARGS_VAR) return arg_defs, call_args, arg_attributes, buffer_types + +class MLIRTile(): + TILE_ROW_WISE = 0 + TILE_COL_WISE = 1 + TILE_PER_LANE_ROW_WISE = 2 + TILE_PER_LANE_COL_WISE = 3 + def __init__(self, n_row, n_col, vector_lane, used_vector_lane=None) -> None: + self.n_row = n_row + self.n_col = n_col + self.vector_lane = vector_lane + if used_vector_lane is None: + self.used_vector_lane = self.vector_lane + else: + self.used_vector_lane = used_vector_lane + self.tile_per_lane_layout = self.TILE_PER_LANE_ROW_WISE # How a given tile per lane is stored + self.tile_layout = self.TILE_ROW_WISE # How a given tile is stored per lane + self.vector_lane_axis = (self.n_col//self.used_vector_lane) > 0 #(0: Col major, 1: Row major) + + def get_tile_size(self): + return self.n_row * self.n_col + + def get_rows_per_lane(self): + if self.n_row % self.used_vector_lane != 0 and self.n_row > 1: + print(f"[Warning] n_row({self.n_row}) % vector_lane({self.used_vector_lane}) != 0") + return self.div_round_up(self.n_row, self.used_vector_lane) + + def get_cols_per_lane(self): + if self.n_col % self.used_vector_lane != 0 and self.n_col > 1: + print(f"[Warning] n_col({self.n_col}) % vector_lane({self.used_vector_lane}) != 0") + return self.div_round_up(self.n_col, self.used_vector_lane) + + def get_tile_size_per_lane(self): + if self.get_tile_size() % self.used_vector_lane != 0: + print(f"[Warning] n_col({self.n_col}) % vector_lane({self.used_vector_lane}) != 0") + return self.div_round_up(self.get_tile_size(), self.used_vector_lane) + + def get_tile_shape(self): + return f"{self.n_row}x{self.n_col}" + + def get_chunk_size(self): + if self.tile_layout == self.TILE_ROW_WISE: + chunk_size = self.get_tile_size_per_lane() + else: + chunk_size = self.get_cols_per_lane() + return chunk_size + + @staticmethod + def div_round_up(size, round_val): + return (size + round_val - 1) // round_val + class 
BaseMLIRHardwareInfo(): def __init__(self): # Default HW setting @@ -163,18 +213,43 @@ class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): def __init__(self, args=None): super().__init__(args) + self.kernel_group = None + # Kernel iteration range info + self.call_ranges = None + self.ranges = None + self.reduction_depth = None self.itervars = None - + # Code buffer self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() self.cse = common.CSE(self.newvar_prefix, self.suffix) - self.tile_row = extension_config.CONFIG_TILE_ROW - if self.tile_row == -1: - self.tile_row = self.vlen * self.vector_lane - self.tile_col = extension_config.CONFIG_TILE_COL - if self.tile_col == -1: - self.tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen - self.var_info = {} + # Tile size setting + tile_row = extension_config.CONFIG_TILE_ROW + if tile_row == -1: + tile_row = self.vlen * self.vector_lane + tile_col = extension_config.CONFIG_TILE_COL + if tile_col == -1: + tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen + self.tile_desc = MLIRTile(tile_row, tile_col, self.vector_lane) + self.var_info = {} # MLIR variable info + + def set_ranges(self, lengths, reduction_lengths, read_writes): + self.read_writes = read_writes + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple( + reduction_lengths + ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + self.reduction_depth = len(lengths) + + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) def load(self, name: str, index: sympy.Expr): raise NotImplementedError() @@ -188,13 +263,6 @@ def store(self, name, index, value, mode=None): def 
reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError() - def check_dtype_in_args(self, args): - dtype = torch.float32 # default dtype - for arg in args: - if arg in list(DTYPE_TO_MLIR.keys()): - dtype = arg - return dtype - def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] return constant_vector @@ -229,6 +297,9 @@ def find_node_by_name(self, name): if output_node.data.name == name: return output_node + def roundup_vectorlane(self, size, amp=1): + return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp + def register_var_info(self, var, var_info): self.var_info[var] = var_info From 4b28d4a28758bcbd035809012d77ea862957fa1a Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 10 Jan 2025 05:21:50 +0000 Subject: [PATCH 4/4] [Frontend] Cleanup mlir_codegen_backend module --- .../mlir/mlir_codegen_backend.py | 312 ++++++++---------- PyTorchSimFrontend/mlir/mlir_common.py | 102 +++++- PyTorchSimFrontend/mlir/mlir_scheduling.py | 12 +- PyTorchSimFrontend/mlir/mlir_template.py | 4 +- 4 files changed, 222 insertions(+), 208 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 25eadbbe..99c39322 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -627,116 +627,17 @@ def __init__(self): self.tags = set() self.dma_cache = {} self.dma_counter = 1 - self.reduction_idx = {} self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} self.is_template_kernel = False - def get_dma_info(self, name, index, dtype): - current_tile = MLIRTile(self.tile_desc.n_row, self.tile_desc.n_col, self.tile_desc.vector_lane, self.tile_desc.used_vector_lane) - cv = self.get_constant_vector(index) - cv2 = self.get_constant_vector2(index) - tile_size_per_lane = self.tile_desc.get_tile_size_per_lane() # FIXME. 
move this - tile_size_per_lane = 2 if tile_size_per_lane==1 else tile_size_per_lane # Avoid scalar operation - - if len(cv) != len(cv2) and len(cv2) == 3: - print("Mismatch! ", cv) - # FIXME. this is really shitty code :( - cv = cv2#[[1 if x[0] == 0 else x[0], x[1]] for x in cv] - - # Case 0. Tile is 0-D scalar - if len(cv) == 0: - # Use only one vectorlane to handle scalar data - current_tile.n_row = 1 - current_tile.n_col = 1 - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - mm_stride, tile_size_per_lane = 1, 1 - chunk_size = current_tile.get_chunk_size() - # Case 1. Tile is 1-D vector type - elif len(cv) == 1 and len(cv) <= self.reduction_depth: - current_tile.n_row = 1 - current_tile.n_col = self.tile_desc.get_tile_size() - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case - chunk_size = current_tile.get_chunk_size() - mm_stride = current_tile.n_col - # Case 2. Tile is 1-D vector type with reduction - elif len(cv) == 1 and len(cv) == self.reduction_depth + 1: - # Use only one vectorlane to reduce a vector - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - current_tile.n_row = 1 - current_tile.n_col = self.tile_desc.get_tile_size() - current_tile.used_vector_lane = 1 - chunk_size = current_tile.get_chunk_size() - mm_stride = 0 # don't care - # Case 3. 
Tile is 2-D tile - elif len(cv) == 2: - is_reduction = self.reduction_depth == 1 - if cv[0][0] != 0 and cv[1][0] != 0: - is_transposed = cv[0][0] < cv[1][0] - if is_transposed: - current_tile.n_row = self.tile_desc.n_col - current_tile.n_col = self.tile_desc.n_row - mm_stride = self.ranges[0] - else: - current_tile.n_row = self.tile_desc.n_row - current_tile.n_col = self.tile_desc.n_col - mm_stride = self.ranges[1] - - if is_reduction and is_transposed: - current_tile.tile_layout = MLIRTile.TILE_COL_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - chunk_size = current_tile.get_chunk_size() - elif is_reduction and not is_transposed: - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE - chunk_size = current_tile.get_chunk_size() - elif not is_reduction and is_transposed: - # Transposed case - current_tile.tile_layout = MLIRTile.TILE_COL_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE - chunk_size = current_tile.get_chunk_size() - else: # not is_reduction and not is_transpose - current_tile.tile_layout = MLIRTile.TILE_COL_WISE if self.tile_desc.vector_lane_axis else MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - chunk_size = current_tile.get_chunk_size() - else: - # Broadcast pattern - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - mm_stride = 0 - if cv[0][0] == 0: - current_tile.tile_layout = MLIRTile.TILE_COL_WISE if self.tile_desc.vector_lane_axis else MLIRTile.TILE_ROW_WISE - current_tile.n_row = self.tile_desc.n_row - current_tile.n_col = self.tile_desc.n_col - chunk_size = current_tile.get_chunk_size() - else: # cv[1][0] == 0 - current_tile.n_row = self.tile_desc.n_col - current_tile.n_col = self.tile_desc.n_row - chunk_size = current_tile.get_cols_per_lane() - if not is_reduction: - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE - chunk_size = 
current_tile.n_col if self.tile_desc.vector_lane_axis else chunk_size - elif len(cv) == 3: - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case - mm_stride = cv[-1][0] - # When current_tile.n_col stride is 1, we can access row vector - if mm_stride == 1: - current_tile.n_row = 1 - current_tile.n_col = self.tile_desc.get_tile_size() - # if current_tile.n_col stride is not 1, we have to access in a column vector - else: - current_tile.n_row = self.tile_desc.get_tile_size() - current_tile.n_col = 1 - chunk_size = current_tile.get_tile_size_per_lane() - else: - raise NotImplementedError() + def set_ranges(self, lengths, reduction_lengths, read_writes): + ret = super().set_ranges(lengths, reduction_lengths, read_writes) - #assert(not (dtype==torch.bool and chunk_size < 8)) - chunk = chunk_size << 1 | (current_tile.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) - return mm_stride, chunk, [current_tile.n_row, current_tile.n_col], tile_size_per_lane + # Adjust time size when it is vector + self.adjust_tile_size() + return ret def parse_indices(self, expr): if len(expr.args) == 0: @@ -766,35 +667,6 @@ def parse_indices(self, expr): index = self.cse.generate(self.loads, f"affine.apply #{map_var}({args})") return index - def codegen_nodes(self, nodes, kernel_name): - _, (group, reduction_group) = max( - nodes, key=lambda x: int(x.is_reduction()) - ).group - - self.set_ranges(group, reduction_group, None) - with self as kernel: - kernel.args = kernel.kernel_group.args - for node in nodes: - vars, reduction_vars = kernel.set_ranges(group, reduction_group, node.read_writes) - kernel.args.tile_row = kernel.tile_desc.n_row - kernel.args.tile_col = kernel.tile_desc.n_col - _, _, _, kernel.buffer_types = kernel.args.mlir_argdefs() - kernel.reduction_idx = {var: i for i, var in enumerate(reduction_vars)} - node.run(vars, reduction_vars) - src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() 
- - write_path = extension_codecache.get_write_path(src_code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header.getvalue()) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header.getvalue()) - return src_code - def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) indices = self.parse_indices(index) @@ -1000,9 +872,9 @@ def store_reduction(self, name, index, value): # MVOUT Encoding dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 mm_stride = tile_col - is_col_major = MLIRTile.TILE_PER_LANE_ROW_WISE + is_col_major = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE chunk_size = self.tile_desc.get_rows_per_lane() - chunk = chunk_size << 1 | (is_col_major == MLIRTile.TILE_PER_LANE_COL_WISE) + chunk = chunk_size << 1 | (is_col_major == mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE) self.consts.add(dmaType) self.consts.add(mm_stride) self.consts.add(chunk) @@ -1031,6 +903,9 @@ def template_store(options): self.compute.clear() self.stores.clear() + def codegen_global_init(self): + return self.global_vars + def codegen_init(self): code = IndentedBuffer() tags = sorted(self.tags) @@ -1046,8 +921,6 @@ def codegen_loops(self): # Loop body part tile_row, tile_col = self.tile_desc.n_row, self.tile_desc.n_col # FIXME. 
- #if (self.tiling_idx < self.reduction_depth and len(self.reduction_idx) > 0): - # tile_row, tile_col = self.tile_desc.n_col, self.tile_desc.n_row tile_row = self.tile_desc.get_tile_size() if len(self.itervars) == 1 else tile_row loops = [LoopLevel(var, size, idx-len(self.itervars), tile_row=tile_row, tile_col=tile_col) for idx, (var, size) in enumerate(zip(self.itervars, self.ranges))] loops, reductions = [LoopNest(loops[: self.reduction_depth]), @@ -1082,43 +955,125 @@ def codegen_loops(self): code.writeline(f"return") return code - def codegen_kernel(self, kernel_name): - wrapper = V.graph.wrapper_code - arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() - code = self._codegen_kernel(arg_defs, kernel_name) - return code.getvalue() - - def meta_kernel(self): - wrapper = V.graph.wrapper_code - _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - wrapper.add_import_once(f'\nfrom PyTorchSimFrontend.extension_codecache import CustomAsyncCompile') - wrapper.add_import_once(f'\ncustom_async_compile = CustomAsyncCompile()') - # Dump loop and load/store information - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") - - - def call_kernel(self, kernel_name): - wrapper = V.graph.wrapper_code - _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() - # generate the code to call this - wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) - - def _codegen_kernel(self, arg_defs, kernel_name): - arg_defs = ",\n".ljust(25).join(arg_defs) - code = common.BracesBuffer() - - code.splice(self.global_vars) - #TODO:. 
kernel name custom
-        kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
-        code.writeline(f'func.func @{kernel_decl_name}({arg_defs})')
-        with code.indent():
-            for old, new in self.kernel_group.args.aliases():
-                code.writeline(f"auto {old} = {new};")
-            # Loop body part
-            code.splice(self.codegen_init())
-            code.splice(self.codegen_loops())
-        return code
+    def codegen_nodes(self, nodes, kernel_name):
+        src_code = super().codegen_nodes(nodes, kernel_name)
+
+        # Create extra header for simulators
+        write_path = extension_codecache.get_write_path(src_code)
+        if not os.path.exists(write_path):
+            os.makedirs(write_path)
+        spike_write_path = os.path.join(write_path, "global_var.h")
+        gem5_write_path = os.path.join(write_path, "gem5_global_var.h")
+        if not os.path.exists(spike_write_path):
+            write_atomic(spike_write_path, self.header.getvalue())
+        if not os.path.exists(gem5_write_path):
+            write_atomic(gem5_write_path, self.gem5_header.getvalue())
+        return src_code
+
+    def get_dma_info(self, name, index, dtype):
+        current_tile = mlir_common.MLIRTile(self.tile_desc.n_row, self.tile_desc.n_col, self.tile_desc.vector_lane, self.tile_desc.used_vector_lane)
+        cv = self.get_constant_vector(index)
+        cv2 = self.get_constant_vector2(index)
+        tile_size_per_lane = self.tile_desc.get_tile_size_per_lane() # FIXME. move this
+        tile_size_per_lane = 2 if tile_size_per_lane==1 else tile_size_per_lane # Avoid scalar operation
+
+        if len(cv) != len(cv2) and len(cv2) == 3:
+            print("Mismatch! ", cv)
+            # FIXME. this is really shitty code :(
+            cv = cv2#[[1 if x[0] == 0 else x[0], x[1]] for x in cv]
+
+        # Case 0.
Tile is 0-D scalar + if len(cv) == 0: + # Use only one vectorlane to handle scalar data + current_tile.n_row = 1 + current_tile.n_col = 1 + current_tile.tile_layout = mlir_common.MLIRTile.TILE_ROW_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE + mm_stride, tile_size_per_lane = 1, 1 + chunk_size = current_tile.get_chunk_size() + # Case 1. Tile is 1-D vector type + elif len(cv) == 1 and len(cv) <= self.reduction_depth: + current_tile.n_row = 1 + current_tile.n_col = self.tile_desc.get_tile_size() + current_tile.tile_layout = mlir_common.MLIRTile.TILE_ROW_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case + chunk_size = current_tile.get_chunk_size() + mm_stride = current_tile.n_col + # Case 2. Tile is 1-D vector type with reduction + elif len(cv) == 1 and len(cv) == self.reduction_depth + 1: + # Use only one vectorlane to reduce a vector + current_tile.tile_layout = mlir_common.MLIRTile.TILE_ROW_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE + current_tile.n_row = 1 + current_tile.n_col = self.tile_desc.get_tile_size() + current_tile.used_vector_lane = 1 + chunk_size = current_tile.get_chunk_size() + mm_stride = 0 # don't care + # Case 3. 
Tile is 2-D tile + elif len(cv) == 2: + is_reduction = self.reduction_depth == 1 + if cv[0][0] != 0 and cv[1][0] != 0: + is_transposed = cv[0][0] < cv[1][0] + if is_transposed: + current_tile.n_row = self.tile_desc.n_col + current_tile.n_col = self.tile_desc.n_row + mm_stride = self.ranges[0] + else: + current_tile.n_row = self.tile_desc.n_row + current_tile.n_col = self.tile_desc.n_col + mm_stride = self.ranges[1] + + if is_reduction and is_transposed: + current_tile.tile_layout = mlir_common.MLIRTile.TILE_COL_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE + chunk_size = current_tile.get_chunk_size() + elif is_reduction and not is_transposed: + current_tile.tile_layout = mlir_common.MLIRTile.TILE_ROW_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE + chunk_size = current_tile.get_chunk_size() + elif not is_reduction and is_transposed: + # Transposed case + current_tile.tile_layout = mlir_common.MLIRTile.TILE_COL_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE + chunk_size = current_tile.get_chunk_size() + else: # not is_reduction and not is_transpose + current_tile.tile_layout = mlir_common.MLIRTile.TILE_COL_WISE if self.tile_desc.vector_lane_axis else mlir_common.MLIRTile.TILE_ROW_WISE + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE + chunk_size = current_tile.get_chunk_size() + else: + # Broadcast pattern + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE + mm_stride = 0 + if cv[0][0] == 0: + current_tile.tile_layout = mlir_common.MLIRTile.TILE_COL_WISE if self.tile_desc.vector_lane_axis else mlir_common.MLIRTile.TILE_ROW_WISE + current_tile.n_row = self.tile_desc.n_row + current_tile.n_col = self.tile_desc.n_col + chunk_size = current_tile.get_chunk_size() + else: # cv[1][0] == 0 + current_tile.n_row = self.tile_desc.n_col + current_tile.n_col = self.tile_desc.n_row + chunk_size = 
current_tile.get_cols_per_lane() + if not is_reduction: + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE + chunk_size = current_tile.n_col if self.tile_desc.vector_lane_axis else chunk_size + elif len(cv) == 3: + current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case + mm_stride = cv[-1][0] + # When current_tile.n_col stride is 1, we can access row vector + if mm_stride == 1: + current_tile.n_row = 1 + current_tile.n_col = self.tile_desc.get_tile_size() + # if current_tile.n_col stride is not 1, we have to access in a column vector + else: + current_tile.n_row = self.tile_desc.get_tile_size() + current_tile.n_col = 1 + chunk_size = current_tile.get_tile_size_per_lane() + else: + raise NotImplementedError() + + #assert(not (dtype==torch.bool and chunk_size < 8)) + chunk = chunk_size << 1 | (current_tile.tile_per_lane_layout == mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE) + return mm_stride, chunk, [current_tile.n_row, current_tile.n_col], tile_size_per_lane def adjust_tile_size(self): if self.read_writes is not None: @@ -1167,13 +1122,6 @@ def adjust_tile_size(self): if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars): raise NotImplementedError() - def set_ranges(self, lengths, reduction_lengths, read_writes): - ret = super().set_ranges(lengths, reduction_lengths, read_writes) - - # Adjust time size when it is vector - self.adjust_tile_size() - return ret - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] @@ -1202,8 +1150,6 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer, indices - - @dataclasses.dataclass class LoopLevel: var: 
sympy.Expr diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index df44a810..a949cb5d 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -1,6 +1,7 @@ import os import torch from torch._inductor.codegen import common +from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout import sympy @@ -191,6 +192,11 @@ def get_chunk_size(self): def div_round_up(size, round_val): return (size + round_val - 1) // round_val +class MLIRWrapperKenrelGroup(cpp.KernelGroup): + def __init__(self): + super().__init__() + self.args = MLIRKernelArgs() + class BaseMLIRHardwareInfo(): def __init__(self): # Default HW setting @@ -213,7 +219,7 @@ class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): def __init__(self, args=None): super().__init__(args) - self.kernel_group = None + self.kernel_group : MLIRWrapperKenrelGroup = None # Kernel iteration range info self.call_ranges = None self.ranges = None @@ -232,6 +238,8 @@ def __init__(self, args=None): tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen self.tile_desc = MLIRTile(tile_row, tile_col, self.vector_lane) self.var_info = {} # MLIR variable info + self.buffer_types : dict = None + self.read_writes = None def set_ranges(self, lengths, reduction_lengths, read_writes): self.read_writes = read_writes @@ -263,6 +271,70 @@ def store(self, name, index, value, mode=None): def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError() + def codegen_global_init(self): + raise NotImplementedError() + + def codegen_loops(self): + raise NotImplementedError() + + def codegen_init(self): + raise NotImplementedError() + + def call_kernel(self, kernel_name): + wrapper = V.graph.wrapper_code + _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() + # generate the code to call this + wrapper.generate_kernel_call(kernel_name, call_args, 
cuda=False) + + def codegen_nodes(self, nodes, kernel_name): + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + self.set_ranges(group, reduction_group, None) + with self as kernel: + kernel.args = kernel.kernel_group.args + for node in nodes: + vars, reduction_vars = kernel.set_ranges(group, reduction_group, node.read_writes) + kernel.args.tile_row = kernel.tile_desc.n_row + kernel.args.tile_col = kernel.tile_desc.n_col + _, _, _, kernel.buffer_types = kernel.args.mlir_argdefs() + node.run(vars, reduction_vars) + src_code = self.codegen_kernel(kernel_name=kernel_name) + self.meta_kernel() + return src_code + + def codegen_kernel(self, kernel_name): + arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() + code = self._codegen_kernel(arg_defs, kernel_name) + return code.getvalue() + + def _codegen_kernel(self, arg_defs, kernel_name): + arg_defs = ",\n".ljust(25).join(arg_defs) + code = common.BracesBuffer() + + #TODO:. kernel name custom + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + + code.splice(self.codegen_global_init()) + code.writeline(f'func.func @{kernel_decl_name}({arg_defs})') + with code.indent(): + for old, new in self.kernel_group.args.aliases(): + code.writeline(f"auto {old} = {new};") + # Loop body part + code.splice(self.codegen_init()) + code.splice(self.codegen_loops()) + return code + + def meta_kernel(self): + wrapper = V.graph.wrapper_code + _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() + wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') + wrapper.add_import_once(f'\nfrom PyTorchSimFrontend.extension_codecache import CustomAsyncCompile') + wrapper.add_import_once(f'\ncustom_async_compile = CustomAsyncCompile()') + # Dump loop and load/store information + wrapper.add_import_once(f"arg_attributes = {arg_attributes}") + def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] 
return constant_vector @@ -303,6 +375,20 @@ def roundup_vectorlane(self, size, amp=1): def register_var_info(self, var, var_info): self.var_info[var] = var_info + def rename_indexing(self, index) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if x.name.startswith("s") or x.name.startswith("ps") + } + return sympy_subs(index, replacements) + def __enter__(self): class CSEProxy: self.name = "CSEProxy" @@ -407,16 +493,4 @@ def bucketize( self.exit_stack.enter_context(V.set_kernel_handler(self)) return self - def rename_indexing(self, index) -> sympy.Expr: - # adds the necessary kernel args for index expressions - # and renames variables in index expressions to kernel arg names - if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] - index = V.graph.sizevars.simplify(index) - sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) - replacements = { - x: self.args.size(x) - for x in sorted_symbols - if x.name.startswith("s") or x.name.startswith("ps") - } - return sympy_subs(index, replacements) + diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index ea2005a8..752fa8b4 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -1,9 +1,8 @@ +import math from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel - from torch._inductor import config -from torch._inductor.codegen import cpp from torch._inductor.scheduler import BaseScheduling from torch._inductor.utils import IndentedBuffer from 
torch._inductor.virtualized import V @@ -11,17 +10,12 @@ from . import mlir_common from . import mlir_lowering -class MLIRWrapperKenrelGroup(cpp.KernelGroup): - def __init__(self): - super().__init__() - self.args = mlir_common.MLIRKernelArgs() - class MLIRScheduling(BaseScheduling): count = 0 target_kernel = MLIRKernel def __init__(self, scheduler): self.scheduler = scheduler - self.kernel_group = MLIRWrapperKenrelGroup() + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False self.outer_function = set() config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! @@ -82,7 +76,7 @@ def codegen_sync(self): def flush(self): self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = MLIRWrapperKenrelGroup() + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._set_flush_status(False) def define_function(self, kernel): diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index ec1340e7..f8e2b428 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -21,8 +21,8 @@ from torch._inductor.virtualized import V from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest -from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, MLIRTile +from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo, MLIRTile +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel from . import mlir_common