diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 1e756f96..ca669361 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -299,7 +299,7 @@ def dummy_simulator(*args, **kwargs): # Dump arguments and meta data dump_metadata(args, arg_attributes, result_path) runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) - if extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate: + if not autotune and (extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate): funcsim = FunctionalSimulator(result_path, key) funcsim.run_spike(args, arg_attributes, runtime_path, self.validation_binary_name, diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 80675682..fa5d22b5 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -46,7 +46,9 @@ # AUTOTUNE config CONFIG_AUTOTUNE = int(os.environ.get('AUTOTUNE', default=True)) +CONFIG_AUTOTUNE_TEMPLATE = int(os.environ.get('AUTOTUNE_TEMPLATE', default=True)) CONFIG_MAX_AUTOTUNE_TRY = int(os.environ.get('MAX_AUTOTUNE_TRY', default=10)) +CONFIG_AUTOTUNE_TEMPLATE_TOPK = int(os.environ.get('AUTOTUNE_TEMPLATE_TOPK', default=4)) # For block sparse CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0)) diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 54aed9c0..537809de 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -74,7 +74,8 @@ def cached_run_fn(*args, **kwargs): self.source_code, vectorlane_size=self.extra_args["vector_lane"], loop_size=None, spad_info=self.extra_args["spad_info"], vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"], - origins="Unknown", silent_mode=True) + origins="Unknown", silent_mode=True, + validate=self.extra_args['validate'], 
autotune=self.extra_args['autotune']) args = [ tensor diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 79e03bd5..178ea987 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -6,8 +6,6 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common BMM_TEMPLATE = r""" @@ -162,51 +160,31 @@ def render(self, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): - if template_buffer_node is not None: - self.output_node = template_buffer_node - - # Extract input arguments info - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - W_tensor = empty_strided(W.layout.size, W.layout.stride) - X_tensor = empty_strided(X.layout.size, X.layout.stride) - if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2: - W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) - if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2: - X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) - B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] - - W_stride = W_tensor.stride() - X_stride = X_tensor.stride() - - # Select tile size - n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane - SUB_TILE_N = TILE_N # if (TILE_N < 
kernel.vector_lane) or prologue_nodes else kernel.vector_lane - SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + if tile_info is None: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0] + else: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - TILE_K = TILE_K // 2 if prologue_nodes else TILE_K # Select template code nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] if nr_reduction_nodes: - template = BMM_REDUCTION_TEMPLATE - epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} - nr_rdim = 1 + template = BMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} + nr_rdim = 1 elif prologue_nodes: - template = BMM_PROLOGUE_TEMPLATE - epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} - nr_rdim = 0 + template = BMM_PROLOGUE_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 else: - template = BMM_TEMPLATE - epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} - nr_rdim = 0 + template = BMM_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 # Prepare tile descriptors vlane_stride = 1 @@ -323,19 +301,53 @@ def render(self, dram_idx = Y_idx, dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, + r_dim_size = M, dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) 
kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + # Extract input arguments info + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + W_tensor = empty_strided(W.layout.size, W.layout.stride) + X_tensor = empty_strided(X.layout.size, X.layout.stride) + if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2: + W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) + if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2: + X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) + B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + + W_stride = W_tensor.stride() + X_stride = X_tensor.stride() + + # Select tile size + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + return X,W,Y,Bias,W_tensor,X_tensor,B,M,N,K,n_extra_node, n_prologue_node + + def get_tile_candidates(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = 
None, + prologue_nodes: Optional[List[IRNode]] = None, + **kwargs): + X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node) + + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + TILE_K = TILE_K // 2 if n_prologue_node else TILE_K + tile_candidates[idx] = TILE_M,TILE_N,TILE_K,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 09ee129b..b3352ea6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -17,8 +17,7 @@ sympy_product ) from torch.utils._sympy.functions import ModularIndexing, FloorDiv -import PyTorchSimFrontend.extension_codecache as extension_codecache - +from PyTorchSimFrontend import extension_codecache from PyTorchSimFrontend import extension_config from . 
import mlir_common from .mlir_common import LoopLevel, LoopNest @@ -1565,10 +1564,10 @@ def make_choices(self, nodes, kernel_name): current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) - print(f"[Auto-tune] Trying tile size: {current_tile_sz}, vlane_stride: {vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group)) + choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) while prevent_infinite_loop < 10 and candidate_axes: for axis in list(candidate_axes): @@ -1593,6 +1592,13 @@ def make_choices(self, nodes, kernel_name): src_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) + # FIXME. How to integrate this constraint to tile system? 
+ pad = self.kernel_group.tile_desc.vmap.get_used_vlane(current_tile_sz) * self.kernel_group.tile_desc.vmap.vlane_stride + vlane_size = current_tile_sz[self.kernel_group.tile_desc.vmap.vlane_split_axis] + if vlane_size > pad and vlane_size % pad: + prevent_infinite_loop += 1 + continue + # If tile size is converged for this axis, remove from candidate axes if current_tile_sz in search_space: candidate_axes.remove(axis) @@ -1600,26 +1606,25 @@ def make_choices(self, nodes, kernel_name): # Add this choice search_space.add(current_tile_sz) - print(f"[Auto-tune] Trying tile size: {current_tile_sz}, vlane_stride: {vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group)) + choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) prevent_infinite_loop += 1 self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold return choices - def autotune(self, nodes, kernel_name): + def autotune(self, *args): def get_cycle(choice): - bench_runner, src_code, kernel_group = choice + bench_runner = choice[0] for n_try in range(extension_config.CONFIG_MAX_AUTOTUNE_TRY): # TODO: make simple try: - # bench_runner = self.run_bench(nodes, kernel_name, src_code) - out = bench_runner(validate=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, autotune=True) + out = bench_runner() return out[-1] except (extension_codecache.SpadOverflowError, RuntimeError) as e: return float("inf") return float("inf") # Exceeded maximum number of autotuning attempts - choices = self.make_choices(nodes, kernel_name) + choices = self.make_choices(*args) 
if len(choices) == 0: # can't autotune return None @@ -1628,21 +1633,25 @@ def get_cycle(choice): max_idx = results.index(min(results)) if min(results) == float("inf"): raise RuntimeError("Failed to find optimal tile size...") - print(f"[Auto-tune] Optimal tile size: {choices[max_idx][2].tile_desc.get_tile_size()}, vlane_stride: {choices[max_idx][2].tile_desc.vmap.vlane_stride}, cycles: {results[max_idx]}") + self._log_autotune_result(choices[max_idx], results[max_idx]) optimal_src_code = choices[max_idx][1] return optimal_src_code + def _log_autotune_result(self, best_choice, best_cycle): + print( + f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, " + f"vlane_stride: {best_choice[3]}, " + f"cycles: {best_cycle}" + ) + def codegen_nodes(self, nodes, kernel_name): src_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) - if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: - return src_code - else: + if extension_config.CONFIG_AUTOTUNE and not extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: optimal_src_code = self.autotune(nodes, kernel_name) - if optimal_src_code: + if optimal_src_code is not None: return optimal_src_code - else: - return src_code + return src_code def _prepare_simulator_headers(self, src_code): write_path = extension_codecache.get_write_path(src_code) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 67d5380f..2644f125 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -408,6 +408,7 @@ def select_vlane_axis(self): self.vmap.vlane_split_axis = best_vlane_split_axis def pad_vlane_tile(self): + # FIXME. this doesn't follow tile constraints... 
vlane_split_axis, vlane_stride, vector_lane = self.vmap.vlane_split_axis, self.vmap.vlane_stride, self.vmap.vector_lane used_vlane = min(math.ceil(self._tile_size[vlane_split_axis] / vlane_stride), vector_lane) padded_size = used_vlane * vlane_stride @@ -790,7 +791,9 @@ def run_bench(self, nodes, kernel_name, src_code): "vector_lane" : self.vector_lane, "spad_info": self.spad_info, "vlen" : self.vlen, - "arg_attributes" : arg_attributes + "arg_attributes" : arg_attributes, + "validate" : extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + "autotune" : True, }, source_code=src_code, ) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py new file mode 100644 index 00000000..52979d73 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -0,0 +1,120 @@ +import os +import math +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from PyTorchSimFrontend import extension_config + +class MLIRConvCommonTemplate(MLIRTemplate): + WRAPPER_TEMPLATE = None + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [str(i) for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + 
input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + raise NotImplementedError() + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + raise NotImplementedError() + + def extract_info(self, kernel, template_buffer_node, epilogue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W + + def get_tile_candidates(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, 
K_H, K_W, O_H, O_W) + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(self.WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index 6dd17576..26018a94 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -1,16 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel 
from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" @@ -104,7 +98,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvMultiTileTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -130,63 +125,26 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvMultiTileTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], 
self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -284,69 +242,13 @@ def 
render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_K = TILE_K - - TOG_latency = O_W if TILE_M > O_W else TILE_M - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - 
X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file + tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 8b1bf7c5..a2959b4d 100644 --- 
a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -1,17 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash -from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" // Single Batch Conv2D kernel @@ -105,7 +98,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvSingleBatchTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -131,63 +125,26 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvSingleBatchTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 
2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, 
SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] # Prepare tile descriptors @@ -283,66 +240,13 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - 
PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane + 
SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 2284c86c..afbe9289 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -1,17 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash -from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" // Single Batch Conv2D (Stride != 1) kernel @@ -105,7 +98,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvSingleBatchStridedTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -131,63 +125,26 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvSingleBatchStridedTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i 
in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn 
template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -284,66 +241,13 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency - - def outer_func_render(self, 
kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: 
implement K_W + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 890b76b7..777d0a7b 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -1,17 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash -from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" // Conv2D kernel @@ -109,7 +102,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -135,63 +129,25 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: 
yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None 
else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) - SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info + TOG_latency = BATCH if TILE_M > BATCH else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -289,68 +245,14 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * 
self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N - TOG_latency = BATCH if TILE_M > BATCH else TILE_M - TOG_latency = 8 if TOG_latency < 8 else TOG_latency - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - 
gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file + tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index ae793c06..c2120e7b 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -8,8 +8,6 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir import mlir_common @@ -114,30 +112,13 @@ def render(self, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = 
None, prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): - if template_buffer_node is not None: - self.output_node = template_buffer_node - - # Extract input arguments info - X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node - X_tensor = empty_strided(X.layout.size, X.layout.stride) - W_tensor = empty_strided(W.layout.size, W.layout.stride) - if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: - raise NotImplementedError("Please report this case to us...") - - # Extract fusion info - n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 - n_extra_read = set() - if epilogue_nodes is not None: - for enode in epilogue_nodes: - n_extra_read.update(enode.node.get_read_names()) - if self.output_node.name in n_extra_read: - n_extra_read.remove(self.output_node.name) - - # Select tile size - M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + if tile_info is None: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node)[0] + else: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info # Select template code if (M == 0) or (N == 0) or (K == 0): # exception for MoE @@ -275,12 +256,47 @@ def render(self, dram_idx = Y_idx, dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, + r_dim_size = M, dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], 
[kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code + def get_tile_candidates(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + **kwargs): + X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + # Extract input arguments info + X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + X_tensor = empty_strided(X.layout.size, X.layout.stride) + W_tensor = empty_strided(W.layout.size, W.layout.stride) + if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: + raise NotImplementedError("Please report this case to us...") + + # Extract fusion info + n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + n_extra_read = set() + if epilogue_nodes is not None: + for enode in epilogue_nodes: + n_extra_read.update(enode.node.get_read_names()) + if self.output_node.name in n_extra_read: + n_extra_read.remove(self.output_node.name) + + # Select tile size + M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] + return X,W,Y,M,N,K,n_epilogue_node,n_prologue_node,len(n_extra_read) + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): # Check cheat sheet cheatsheet_path = extension_config.CONFIG_GEMM_CHEATSHEET_PATH @@ -292,52 +308,49 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no data = json.load(f) gemm_shape = f"{M}_{K}_{N}" - if gemm_shape in data: + if 
extension_config.CONFIG_MANUAL_TILE_SIZE: + # case 1: use manual tile size + TILE_M = extension_config.CONFIG_TILE_M + TILE_N = extension_config.CONFIG_TILE_N + TILE_K = extension_config.CONFIG_TILE_K + tile_candidates = [[TILE_M, TILE_N, TILE_K]] + elif gemm_shape in data: + # case 2: cached tile size tile_info = data[gemm_shape] TILE_M = tile_info["TILE_M"] TILE_N = tile_info["TILE_N"] TILE_K = tile_info["TILE_K"] - else: # case 2: use gemm_combination_mapping + tile_candidates = [[TILE_M, TILE_N, TILE_K]] + else: + # case 3: use gemm_combination_mapping min_tile = (n_extra_node + n_prologue_node) == 0 - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=True) - # case 3: use manual tile size - if extension_config.CONFIG_MANUAL_TILE_SIZE: - TILE_M = extension_config.CONFIG_TILE_M - TILE_N = extension_config.CONFIG_TILE_N - TILE_K = extension_config.CONFIG_TILE_K + tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) # Edge case if (M == 0) or (N == 0) or (K == 0): TILE_M, TILE_N, TILE_K = 1, 1, 1 + tile_candidates = [[TILE_M, TILE_N, TILE_K]] - # Calculate Sub Tile Size for fine-grained DMA - if extension_config.CONFIG_SUBTILE: - # Case 1: adjust selective fine-grained DMA (SFG-DMA) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane - if (TILE_M == M and TILE_N == N and TILE_N <= 512): - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - else: # Avoid Row Conflict of weights + full_tile_candidates = [] + for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + # Calculate Sub Tile Size for fine-grained DMA + if extension_config.CONFIG_SUBTILE: + # Case 1: adjust selective fine-grained DMA (SFG-DMA) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane + if (TILE_M == M and TILE_N == N and TILE_N <= 512): + 
SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + else: # Avoid Row Conflict of weights + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + # Case 2: use manual sub tile size (FG-DMA) + if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: + SUB_TILE_M = extension_config.CONFIG_SUBTILE_M + SUB_TILE_N = extension_config.CONFIG_SUBTILE_N + SUB_TILE_K = extension_config.CONFIG_SUBTILE_K + # Case 3: None Subtile + else: + SUB_TILE_M = TILE_M SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K - # Case 2: use manual sub tile size (FG-DMA) - if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: - SUB_TILE_M = extension_config.CONFIG_SUBTILE_M - SUB_TILE_N = extension_config.CONFIG_SUBTILE_N - SUB_TILE_K = extension_config.CONFIG_SUBTILE_K - # Case 3: None Subtile - else: - SUB_TILE_M = TILE_M - SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K - return TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) + SUB_TILE_K = TILE_K + full_tile_candidates.append([TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K]) + return full_tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index 6f605d56..2cca36b6 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -6,8 +6,6 @@ from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode from torch._inductor.ir import ReinterpretView -from torch._inductor.codecache import write_atomic -import 
PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common import sympy @@ -42,6 +40,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node @@ -99,14 +98,3 @@ def render(self, code = self._template_from_string(TEMPLATE).render(**kernel.render_options) kernel.add_loop_info([X.get_numel()], [kernel.vector_lane, kernel.vector_lane]) return code - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 2bbdb41d..26b90401 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -4,8 +4,11 @@ from functools import reduce import operator from sympy import symbols, sympify, Symbol +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel +from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from torch._inductor import config from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode @@ -259,85 +262,6 @@ def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size wrapper.define_kernel(kernel_name, 
codecache_def.getvalue(), cuda=False) return kernel_name - def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): - with kernel: - _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - for node in [template_node, *prologue_nodes, *epilogue_nodes]: - node.mark_run() - # Partial codgen template nodes - partial_code = render() - - # Swap load/store functions - kernel.load = kernel.load_epilogue - kernel.store = kernel.store_epilogue - kernel.store_reduction = kernel.store_reduction_epilogue - kernel.reduction = kernel.reduction_epilogue - - # Codegen prologue nodes - if prologue_nodes: - # Flush created varaibles, since template fusion doen't share variable - with kernel.prologue_buffer_group.as_local(): - _, (group, reduction_group) = max( - [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) - ).group - prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) - kernel.kernel_group.set_tile_info(prologue_tile_desc) - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - for node in prologue_nodes: - # Reuse created spad - read_list = sorted([i.name for i in node.read_writes.reads]) - candidate_found = False - # Why? There is a case that memdep.get_size() != data.get_size() - buf_dict = {} - buf_dict.update({val.name : val for val in V.graph.buffers}) - buf_dict.update(V.graph.graph_inputs) - for candidate_read in read_list: - if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel(): - prologue_input_arg = candidate_read - candidate_found = True - break - assert(candidate_found) - assert(len(node.read_writes.writes)==1) - prologue_output_arg = list(node.read_writes.writes)[0].name - template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] - target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? 
- - # To skip the dma code gen - kernel.buffer_names[prologue_input_arg] = target_buf - kernel.buffer_names[prologue_output_arg] = target_buf - - # Edge delete - kernel.kernel_group.args.input_buffers = { - (arg if buf != template_buf else prologue_input_arg): buf - for arg, buf in kernel.kernel_group.args.input_buffers.items() - } - node.codegen((vars, reduction_vars)) - - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None - if epilogue_nodes: - with kernel.epilogue_buffer_group.as_local(): - _, (group, reduction_group) = max( - epilogue_nodes, key=lambda x: int(x.is_reduction()) - ).group - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - for node in epilogue_nodes: - node.codegen((vars, reduction_vars)) - - with V.set_kernel_handler(kernel): - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) - - # For consistency, white space could make wrong write_path - buffer = IndentedBuffer() - buffer.splice(src_code) - return buffer.getvalue() - def codegen_template(self, template_node, epilogue_nodes): # Handle prologue pattern prologue_nodes = [] @@ -350,24 +274,13 @@ def codegen_template(self, template_node, epilogue_nodes): epilogue_nodes = epilogue_nodes[i+1:] break - _, (numel, rnumel) = template_node.group + # Generate template code template_buffer = template_node.node - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) + kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - - src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) 
- wrapper = V.graph.wrapper_code - - if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined - kernel_name = wrapper.src_to_kernel[src_code] - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name - src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) + src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(kernel): - spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" - spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({kernel.spad_info['spad_size']*kernel.vector_lane})));" - codegen_header(src_code, (kernel.header.getvalue()+spad_end_symbol+spad_section_end_symbol, kernel.gem5_header.getvalue())) - kernel.meta_kernel() kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) self.define_function(kernel) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 4f75dd84..e6e9dd0c 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -6,6 +6,8 @@ import contextlib import math import sympy +from functools import reduce +import operator from collections import OrderedDict from typing import List, Optional @@ -18,14 +20,16 @@ from torch._inductor.autotune_process import TensorMeta from torch._inductor.virtualized import V, NullHandler, _ops as ops from torch._inductor.utils import IndentedBuffer +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from 
PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode from torch._inductor.codegen import common -from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR +from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR, CONFIG_AUTOTUNE_TEMPLATE_TOPK from . import mlir_common class IndentedBufferGroup: @@ -93,7 +97,8 @@ def __init__(self, kernel_group = None, outer_func_name=None, outer_func_render=None, - kernel_arg_attributes=None) -> None: + kernel_arg_attributes=None, + reason=None) -> None: super().__init__(kernel_group if kernel_group is not None else mlir_common.MLIRWrapperKenrelGroup()) self.kernel_name = kernel_name self.input_nodes = input_nodes @@ -125,6 +130,15 @@ def __init__(self, self.reduction_mean = [] # Dim info self.dim_aliasing = {} + self.reason = reason + + def reset(self, reason): + self.__init__( + self.kernel_name, self.input_nodes, + self.call_size, self.kernel_group, + self.outer_func_name, self.outer_func_render, + self.kernel_arg_attributes, reason + ) def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): @@ -185,7 +199,8 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer @@ -249,52 +264,22 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p max_used_spad_size = used_spad_size 
maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) - return mapping - - def search_mapping_space(self, mapping, idx, increment, stride, dilation, n_extra_node=0): - if idx == 0 or idx == 1 or idx == 4 or idx == 5 or idx == 6: - raise NotImplementedError("Only O_H and O_W are supported for search_mapping_space") - spad_size_per_lane = self.spad_info["spad_size"] - spad_size = spad_size_per_lane * self.vector_lane - max_spad_size = spad_size // 2 # double buffer - max_spad_per_lane = spad_size_per_lane // 2 # double buffer - - mapping = list(mapping) - mapping[idx] += increment - k_h, k_w, o_h, o_w, M, N, K = mapping - i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] - i_w = 1 + (o_w - 1) * stride[1] + (k_w - 1) * dilation[1] - weight_size = k_w * k_h * K * N - input_size = i_w * i_h * M * K - output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision - weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) - input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) - output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane: - mapping = (k_h, k_w, o_h, o_w, M, N, K) - else: - mapping[idx] -= increment - - return mapping + if check_spad_size: + tile_candidates.append((used_spad_size, (tile_M, tile_N, tile_K))) - def pseudo_auto_tune(self, mapping, stride, dilation, O_H, O_W, n_extra_node=0): - # pseudo auto-tune - if mapping[2] == 1 and not (O_H == 1): - mapping = self.search_mapping_space(mapping, 2, 1, stride, dilation, n_extra_node=n_extra_node) - if mapping[3] == 1 and not (O_W == 1): - mapping = self.search_mapping_space(mapping, 3, 1, stride, dilation, n_extra_node=n_extra_node) - return mapping + tile_candidates = 
sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer max_spad_per_lane = spad_size_per_lane // 2 # double buffer max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False) + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] max_k_h_w = 1 # maximize kernel size max_o_h_w = 1 # maximize output size K = min(K, self.vector_lane) @@ -312,27 +297,30 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h * k_w and max_o_h_w <= o_h * o_w: - max_used_spad_size = used_spad_size - max_k_h_w = k_h * k_w - max_o_h_w = o_h * o_w - mapping = (k_h, k_w, o_h, o_w, M, N, K) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + tile_candidates.append((used_spad_size, (k_h, k_w, o_h, o_w, M, N, K))) + if max_used_spad_size < used_spad_size and max_k_h_w <= k_h * k_w and max_o_h_w <= o_h * o_w: + max_used_spad_size = used_spad_size + max_k_h_w = k_h * k_w + max_o_h_w = o_h * o_w + mapping = (k_h, k_w, o_h, o_w, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") - # FIXME: this should be 
implemented with auto-tuning - mapping = self.pseudo_auto_tune(mapping, stride, dilation, O_H, O_W, n_extra_node=n_extra_node) - - return mapping + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False) + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] max_k_h_w = K_W for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): @@ -347,22 +335,28 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h: - max_used_spad_size = used_spad_size - max_k_h_w = k_h - mapping = (k_h, K_W, o_h, o_w, M, N, K) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + tile_candidates.append((used_spad_size, (k_h, K_W, o_h, o_w, M, N, K))) + if max_used_spad_size < used_spad_size and max_k_h_w <= k_h: + max_used_spad_size = used_spad_size + max_k_h_w = k_h + mapping = (k_h, K_W, o_h, o_w, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") - return mapping + tile_candidates = 
sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False) + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] max_k_h_w = 1 for o_h in sympy.divisors(O_H): for k_h in sympy.divisors(K_H): @@ -377,13 +371,18 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * k_w, K) output_size_per_lane = self.get_spad_size_per_lane(M * o_h * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h * k_w: - max_used_spad_size = used_spad_size - max_k_h_w = k_h * k_w - mapping = (k_h, k_w, o_h, M, M, N, K) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + tile_candidates.append((used_spad_size, (k_h, k_w, o_h, M, M, N, K))) + if max_used_spad_size < used_spad_size and max_k_h_w <= k_h * k_w: + max_used_spad_size = used_spad_size + max_k_h_w = k_h * k_w + mapping = (k_h, k_w, o_h, M, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") - return mapping + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def meta_kernel(self): 
wrapper = V.graph.wrapper_code @@ -407,6 +406,126 @@ def call_kernel(self, kernel_name): kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) + def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): + with self as kernel: + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() + for node in [template_node, *prologue_nodes, *epilogue_nodes]: + node.mark_run() + + # Partial codegen template nodes + partial_code = render(kwargs={**render.keywords['kwargs'], 'tile_info': tile_info}) + + # Swap load/store functions + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue + kernel.store_reduction = kernel.store_reduction_epilogue + kernel.reduction = kernel.reduction_epilogue + + # Codegen prologue nodes + if prologue_nodes: + # Flush created variables, since template fusion doesn't share variable + with kernel.prologue_buffer_group.as_local(): + _, (group, reduction_group) = max( + [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) + ).group + prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) + kernel.kernel_group.set_tile_info(prologue_tile_desc) + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in prologue_nodes: + # Reuse created spad + read_list = sorted([i.name for i in node.read_writes.reads]) + candidate_found = False + # Why?
There is a case that memdep.get_size() != data.get_size() + buf_dict = {} + buf_dict.update({val.name : val for val in V.graph.buffers}) + buf_dict.update(V.graph.graph_inputs) + for candidate_read in read_list: + if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read + candidate_found = True + break + assert(candidate_found) + assert(len(node.read_writes.writes)==1) + prologue_output_arg = list(node.read_writes.writes)[0].name + template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] + target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? + + # To skip the dma code gen + kernel.buffer_names[prologue_input_arg] = target_buf + kernel.buffer_names[prologue_output_arg] = target_buf + + # Edge delete + kernel.kernel_group.args.input_buffers = { + (arg if buf != template_buf else prologue_input_arg): buf + for arg, buf in kernel.kernel_group.args.input_buffers.items() + } + node.codegen((vars, reduction_vars)) + + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None + if epilogue_nodes: + with kernel.epilogue_buffer_group.as_local(): + _, (group, reduction_group) = max( + epilogue_nodes, key=lambda x: int(x.is_reduction()) + ).group + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in epilogue_nodes: + node.codegen((vars, reduction_vars)) + + with V.set_kernel_handler(kernel): + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + + # For consistency, white space could make wrong write_path + buffer = IndentedBuffer() + buffer.splice(src_code) + src_code = buffer.getvalue() + self._prepare_simulator_headers(src_code) + return src_code + + def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): + choices = [] + for 
tile_info in tile_candidates: + print(f"[Auto-tune] Trying tile size: {list(tile_info)}") + src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + bench_runner = self.run_bench([template_node], self.kernel_name, src_code) + choices.append((bench_runner, src_code, tile_info)) + self.reset(reason=None) + return choices + + def _log_autotune_result(self, best_choice, best_cycle): + tile_size = best_choice[2] + print( + f"[Auto-tune] Optimal tile size: {list(tile_size)}, " + f"cycles: {best_cycle}" + ) + + def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): + src_code = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + + with V.set_kernel_handler(self): + self.meta_kernel() + return src_code + + def _prepare_simulator_headers(self, src_code): + spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" + spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({self.spad_info['spad_size']*self.vector_lane})));" + + write_path = extension_codecache.get_write_path(src_code) + if not os.path.exists(write_path): + os.makedirs(write_path, exist_ok=True) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, self.header.getvalue()+spad_end_symbol+spad_section_end_symbol) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, self.gem5_header.getvalue()) + def codegen_prologue_body(self): body = IndentedBuffer() with self.prologue_buffer_group.as_local(): @@ -773,7 +892,7 @@ def load_epilogue(self, name: str, index: sympy.Expr): vshape = f"vector<{vsize}x{mlir_dtype}>" if compute_vec_size > 1: - offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + 
d1*{(self.reduction_axis_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) operation = "affine.vector_load" line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" @@ -958,12 +1077,7 @@ def store_reduction_epilogue(self, name, index, value): if self.welford_reduce_out is not None: # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size)} : f32") - if self.reduction_axis_size - 1 > 0: - divider2 = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size-1)} : f32") - else: - divider2 = divider - + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.r_dim_size)} : f32") if self.buffer_types[name][1] > 1: divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") else: @@ -1002,15 +1116,16 @@ def set_tile_size(self, template_fusion_info, prologue=False): if 'nr_rdim' in template_fusion_info and template_fusion_info['nr_rdim']==1: tile_desc.nr_rdim = 1 numel_per_lane = tile_desc.get_numel_per_lane() - reduction_axis_size = tile_desc.get_tile_size()[-1] - nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size + r_tile_size = tile_desc.get_tile_size()[-1] + nr_outer_loop = (numel_per_lane + r_tile_size-1) // r_tile_size tile_desc.vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... 
self.reduction_fusion = True - self.reduction_axis_size = tile_desc.get_tile_size()[-1] + self.r_tile_size = tile_desc.get_tile_size()[-1] + self.r_dim_size = template_fusion_info['r_dim_size'] self.reduction_nr_outer_loop = nr_outer_loop self.reduction_loop_idx = "reduce_loop_idx" - self.compute_body_loop.size = reduction_axis_size + self.compute_body_loop.size = r_tile_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: @@ -1110,7 +1225,8 @@ def make_kernel_render( template=self, kwargs=kwargs ) - return kernel, render, self.codegen_header + tile_candidates = self.get_tile_candidates(**kwargs)[:CONFIG_AUTOTUNE_TEMPLATE_TOPK] + return kernel, tile_candidates, render return MLIRTemplateCaller( kernel_hash_name, @@ -1122,5 +1238,8 @@ def make_kernel_render( self, ) + def get_tile_candidates(self, **kwargs): + return [] + def render(self, **kwargs) -> str: raise NotImplementedError \ No newline at end of file